在这篇文章中,我将展示一些@decorators可能对数据科学家有用的东西:
@parallel
让我们假设我写了一个非常低效的方法来寻找素数:
from sympy import isprime def generate_primes(domain: int=1000*1000, num_attempts: int=1000) -> list[int]: primes: set[int] = set() seed(time()) for _ in range(num_attempts): candidate: int = randint(4, domain) if isprime(candidate): primes.add(candidate) return sorted(primes) print(len(generate_primes())) |
输出:88
然后我意识到,如果我在所有的CPU线程上并行运行原来的generate_primes(),我可以得到一个 "免费 "的加速。这是很常见的,定义一个@parallel用法:
def parallel(func=None, args=(), merge_func=lambda x:x, parallelism = cpu_count()): def decorator(func: Callable): def inner(*args, **kwargs): results = Parallel(n_jobs=parallelism)(delayed(func)(*args, **kwargs) for i in range(parallelism)) return merge_func(results) return inner if func is None: # decorator was used like @parallel(...) return decorator else: # decorator was used like @parallel, without parens return decorator(func) |
有了这个,只需一行,我们就可以将我们的函数并行化。
@parallel(merge_func=lambda li: sorted(set(chain(*li)))) def generate_primes(...): # same signature, nothing changes ... # same code, nothing changes print(len(generate_primes())) |
输出:1281
在我的例子中,我的Macbook有8个核心,16个线程(cpu_count()是16),所以我产生了16倍的素数。
注意:
唯一的开销是必须定义一个merge_func,它将函数的不同运行结果合并为一个结果,以便向装饰函数(本例中为 generate_primes())的外部调用者隐藏并行性。在这个玩具例子中,我只是合并了列表,并通过使用 set() 确保素数是唯一的。
有许多Python库和方法(例如线程与进程)可以实现并行。
这个例子使用了joblib.Parallel()的进程并行,它在Darwin + python3 + ipython上运行良好,并且避免了对Python全局解释器锁(GIL)的锁定。
@production
有时候,我们写了一个复杂的管道,有一些额外的步骤,我们只想在某些环境下运行。例如,在我们的本地开发环境中做一些事情,但在生产环境中不做,反之亦然。如果能够对函数进行装饰,让它们只在某些环境下运行,而在其他地方不做任何事情,那就更好了。
实现这一目标的方法之一是使用一些简单的装饰器。@production表示我们只想在prod上运行的东西,@development表示我们只想在dev中运行的东西,我们甚至可以引入一个@inactive,将函数完全关闭。这种方法的好处是,这种方式可以在代码/Github中跟踪部署历史和当前状态。另外,我们可以在一行中做出这些改变,从而使提交更简洁;例如,@inactive比整个代码块被注释掉的大提交要干净。
production_servers = [...] def production(func: Callable): def inner(*args, **kwargs): if gethostname() in production_servers: return func(*args, **kwargs) else: print('This host is not a production server, skipping function decorated with @production...') return inner def development(func: Callable): def inner(*args, **kwargs): if gethostname() not in production_servers: return func(*args, **kwargs) else: print('This host is a production server, skipping function decorated with @development...') return inner def inactive(func: Callable): def inner(*args, **kwargs): print('Skipping function decorated with @inactive...') return inner @production def foo(): print('Running in production, touching databases!') foo() @development def foo(): print('Running in production, touching databases!') foo() @inactive def foo(): print('Running in production, touching databases!') foo() |
输出:
Running in production, touching databases! This host is a production server, skipping function decorated with @development... Skipping function decorated with @inactive... |
这个想法可以适用于其他框架/环境。
@deployable
在我目前的工作中,我们使用Airflow进行ETL/数据管道。我们有一个丰富的辅助函数库,可以在内部构建适当的DAG,所以用户(数据科学家)不必担心这个问题。
最常用的是dag_vertica_create_table_as(),它在我们的Vertica DWH上运行一个SELECT,每晚将结果转储到一个表中。
dag = dag_vertica_create_table_as( table='my_aggregate_table', owner='Marton Trencseni (marton.trencseni@maf.ae)', schedule_interval='@daily', ... select=""" SELECT ... FROM ... """ ) |
然后这就变成了对DWH的查询,大致是这样:
CREATE TABLE my_aggregate_table AS SELECT ... |
实际上,情况更复杂:我们首先运行今天的查询,如果今天的查询被成功创建,则有条件地删除昨天的查询。这个条件逻辑(以及其他一些针对我们环境的意外的复杂性,比如必须发布GRANTs)导致DAG有9个步骤,但这不是这里的重点,也超出了本文的范围。
在过去的两年里,我们已经创建了近500个DAG,所以我们扩大了Airflow EC2实例的规模,并引入了独立的开发和生产环境。如果能有一种方法来标记DAG是应该在开发环境还是生产环境中运行,在代码/Github中跟踪这一点,并使用相同的机制来确保DAG不会意外地运行在错误的环境中,那就更好了。
大约有10个类似的便利函数,如dag_vertica_create_or_replace_view_as()和dag_vertica_train_predict_model()等,我们希望这些dag_xxx()函数的所有调用都可以在生产和开发之间切换(或者到处跳过)。
然而,上一节中的@production和@development装饰器在这里不起作用,因为我们不想将dag_vertica_create_table_as()切换为永远不在其中一个环境中运行。我们希望能够在每次调用时进行设置,并且在我们所有的dag_xxxx()函数中都有这个功能,而不需要复制/粘贴代码。我们想要的是在我们所有的dag_xxxx()函数中添加一个部署参数(有一个好的默认值),这样我们就可以在我们的DAG中添加这个参数,以增加安全性。我们可以通过@deployable装饰器来实现这个目标。
def deployable(func): def inner(*args, **kwargs): if 'deploy' in kwargs: if kwargs['deploy'].lower() in ['production', 'prod'] and gethostname() not in production_servers: print('This host is not a production server, skipping...') return if kwargs['deploy'].lower() in ['development', 'dev'] and gethostname() not in development_servers: print('This host is not a development server, skipping...') return if kwargs['deploy'].lower() in ['skip', 'none']: print('Skipping...') return del kwargs['deploy'] # to avoid func() throwing an unexpected keyword exception return func(*args, **kwargs) return inner |
然后,我们可以将装饰器添加到我们的函数定义中(每个函数添加1行)。
@deployable def dag_vertica_create_table_as(...): # same signature, nothing changes ... # code signature, nothing changes @deployable def dag_vertica_create_or_replace_view_as(...): # same signature, nothing changes ... # code signature, nothing changes @deployable def dag_vertica_train_predict_model(...): # same signature, nothing changes ... # code signature, nothing changes |
如果我们在这里停止,什么也不会发生,我们不会破坏任何东西。
然而,现在我们可以到我们使用这些函数的DAG文件中,增加1行。
dag = dag_vertica_create_table_as( deploy='development', # the function will return None on production ... ) |
@redirect (stdout)
有时我们写一个大的函数,也会调用其他代码,各种信息都会被打印()出来。或者,我们可能有一个bug,有一堆print(),想在打印出来的内容上加上行号,这样就可以更容易地参考它们。在这些情况下,@redirect可能是有用的。这个装饰器将print()的标准输出重定向到我们自己的逐行打印机,我们可以对它做任何我们想做的事情(包括扔掉它)。
def redirect(func=None, line_print: Callable = None): def decorator(func: Callable): def inner(*args, **kwargs): with StringIO() as buf, redirect_stdout(buf): func(*args, **kwargs) output = buf.getvalue() lines = output.splitlines() if line_print is not None: for line in lines: line_print(line) else: width = floor(log(len(lines), 10)) + 1 for i, line in enumerate(lines): i += 1 print(f'{i:0{width}}: {line}') return inner if func is None: # decorator was used like @redirect(...) return decorator else: # decorator was used like @redirect, without parens return decorator(func) |
如果我们使用redirect()而不指定明确的line_print()函数,它就会打印行数,但要加上行号。
@redirect def print_lines(num_lines): for i in range(num_lines): print(f'Line #{i+1}') print_lines(10) Output: 01: Line #1 02: Line #2 03: Line #3 04: Line #4 05: Line #5 06: Line #6 07: Line #7 08: Line #8 09: Line #9 10: Line #10 |
如果我们想把所有的打印文本保存到一个变量中,我们也可以实现这一点。
lines = [] def save_lines(line): lines.append(line) @redirect(line_print=save_lines) def print_lines(num_lines): for i in range(num_lines): print(f'Line #{i+1}') print_lines(3) print(lines) Output: <p class="indent">['Line #1', 'Line #2', 'Line #3'] |
重定向stdout的实际工作是由contextlib.redirect_stdout完成的。
@stacktrace
下一个装饰器模式是@stacktrace,当函数被调用和从函数返回值时,它会发出有用的信息。
def stacktrace(func=None, exclude_files=['anaconda']): def tracer_func(frame, event, arg): co = frame.f_code func_name = co.co_name caller_filename = frame.f_back.f_code.co_filename if func_name == 'write': return # ignore write() calls from print statements for file in exclude_files: if file in caller_filename: return # ignore in ipython notebooks args = str(tuple([frame.f_locals[arg] for arg in frame.f_code.co_varnames])) if args.endswith(',)'): args = args[:-2] + ')' if event == 'call': print(f'--> Executing: {func_name}{args}') return tracer_func elif event == 'return': print(f'--> Returning: {func_name}{args} -> {repr(arg)}') return def decorator(func: Callable): def inner(*args, **kwargs): settrace(tracer_func) func(*args, **kwargs) settrace(None) return inner if func is None: # decorator was used like @stacktrace(...) return decorator else: # decorator was used like @stacktrace, without parens return decorator(func) |
有了这个,我们就可以装饰我们希望追踪开始的最上面的函数,我们将得到关于分支的有用的输出。
def b(): print('...') @stacktrace def a(arg): print(arg) b() return 'world' a('foo') Output: --> Executing: a('foo') foo --> Executing: b() ... --> Returning: b() -> None --> Returning: a('foo') -> 'world' |
这里唯一的诀窍是。在我的例子中,我在Anaconda上的ipython中运行这段代码,所以我隐藏了代码在路径中有Anaconda的文件中的部分调用栈(否则我在上面的片段中会得到大约50-100个无用的调用栈条目)。这是通过装饰器的exclude_files参数完成的。
@traceclass
与上述类似,我们可以定义一个装饰器@traceclass,与类一起使用,以获得其成员的执行轨迹。这包括在之前的装饰器帖子中,在那里它只是被称为@trace,并且有一个bug(在原来的帖子中已经修复)。这个装饰器。
def traceclass(cls: type): def make_traced(cls: type, method_name: str, method: Callable): def traced_method(*args, **kwargs): print(f'--> Executing: {cls.__name__}::{method_name}()') return method(*args, **kwargs) return traced_method for name in cls.__dict__.keys(): if callable(getattr(cls, name)) and name != '__class__': setattr(cls, name, make_traced(cls, name, getattr(cls, name))) return cls 使用: @traceclass class Foo: i: int = 0 def __init__(self, i: int = 0): self.i = i def increment(self): self.i += 1 def __str__(self): return f'This is a {self.__class__.__name__} object with i = {self.i}' f1 = Foo() f2 = Foo(4) f1.increment() print(f1) print(f2) Output: --> Executing: Foo::__init__() --> Executing: Foo::__init__() --> Executing: Foo::increment() --> Executing: Foo::__str__() This is a Foo object with i = 1 --> Executing: Foo::__str__() This is a Foo object with i = 4 |