用于数据科学的几种Python装饰器介绍 - Bytepawn


在这篇文章中,我将展示一些@decorators可能对数据科学家有用的东西:

@parallel
让我们假设我写了一个非常低效的方法来寻找素数:

from sympy import isprime

def generate_primes(domain: int=1000*1000, num_attempts: int=1000) -> list[int]:
    primes: set[int] = set()
    seed(time())
    for _ in range(num_attempts):
        candidate: int = randint(4, domain)
        if isprime(candidate):
            primes.add(candidate)
    return sorted(primes)

print(len(generate_primes()))

输出:88

然后我意识到,如果我在所有的CPU线程上并行运行原来的generate_primes(),我可以得到一个 "免费 "的加速。这是很常见的,定义一个@parallel用法:

def parallel(func=None, args=(), merge_func=lambda x:x, parallelism = cpu_count()):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            results = Parallel(n_jobs=parallelism)(delayed(func)(*args, **kwargs) for i in range(parallelism))
            return merge_func(results)
        return inner
    if func is None:
        # decorator was used like @parallel(...)
        return decorator
    else:
        # decorator was used like @parallel, without parens
        return decorator(func)

有了这个,只需一行,我们就可以将我们的函数并行化。

@parallel(merge_func=lambda li: sorted(set(chain(*li))))
def generate_primes(...): # same signature, nothing changes
    ... # same code, nothing changes

print(len(generate_primes()))

输出:1281

在我的例子中,我的Macbook有8个核心,16个线程(cpu_count()是16),所以我产生了16倍的素数。

注意:
唯一的开销是必须定义一个merge_func,它将函数的不同运行结果合并为一个结果,以便向装饰函数(本例中为 generate_primes())的外部调用者隐藏并行性。在这个玩具例子中,我只是合并了列表,并通过使用 set() 确保素数是唯一的。
有许多Python库和方法(例如线程与进程)可以实现并行。
这个例子使用了joblib.Parallel()的进程并行,它在Darwin + python3 + ipython上运行良好,并且避免了对Python全局解释器锁(GIL)的锁定。

@production
有时候,我们写了一个复杂的管道,有一些额外的步骤,我们只想在某些环境下运行。例如,在我们的本地开发环境中做一些事情,但在生产环境中不做,反之亦然。如果能够对函数进行装饰,让它们只在某些环境下运行,而在其他地方不做任何事情,那就更好了。

实现这一目标的方法之一是使用一些简单的装饰器。@production表示我们只想在prod上运行的东西,@development表示我们只想在dev中运行的东西,我们甚至可以引入一个@inactive,将函数完全关闭。这种方法的好处是,这种方式可以在代码/Github中跟踪部署历史和当前状态。另外,我们可以在一行中做出这些改变,从而使提交更简洁;例如,@inactive比整个代码块被注释掉的大提交要干净。

production_servers = [...]

def production(func: Callable):
    def inner(*args, **kwargs):
        if gethostname() in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is not a production server, skipping function decorated with @production...')
    return inner

def development(func: Callable):
    def inner(*args, **kwargs):
        if gethostname() not in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is a production server, skipping function decorated with @development...')
    return inner

def inactive(func: Callable):
    def inner(*args, **kwargs):
        print('Skipping function decorated with @inactive...')
    return inner

@production
def foo():
    print('Running in production, touching databases!')

foo()

@development
def foo():
    print('Running in production, touching databases!')

foo()

@inactive
def foo():
    print('Running in production, touching databases!')

foo()

输出:

Running in production, touching databases!
This host is a production server, skipping function decorated with @development...
Skipping function decorated with @inactive...

这个想法可以适用于其他框架/环境。

@deployable
在我目前的工作中,我们使用Airflow进行ETL/数据管道。我们有一个丰富的辅助函数库,可以在内部构建适当的DAG,所以用户(数据科学家)不必担心这个问题。

最常用的是dag_vertica_create_table_as(),它在我们的Vertica DWH上运行一个SELECT,每晚将结果转储到一个表中。

dag = dag_vertica_create_table_as(
    table='my_aggregate_table',
    owner='Marton Trencseni (marton.trencseni@maf.ae)',
    schedule_interval='@daily',
    ...
    select="""
    SELECT
        ...
    FROM
        ...
   
"""
)

然后这就变成了对DWH的查询,大致是这样:
CREATE TABLE my_aggregate_table AS
SELECT ...


实际上,情况更复杂:我们首先运行今天的查询,如果今天的查询被成功创建,则有条件地删除昨天的查询。这个条件逻辑(以及其他一些针对我们环境的意外的复杂性,比如必须发布GRANTs)导致DAG有9个步骤,但这不是这里的重点,也超出了本文的范围。

在过去的两年里,我们已经创建了近500个DAG,所以我们扩大了Airflow EC2实例的规模,并引入了独立的开发和生产环境。如果能有一种方法来标记DAG是应该在开发环境还是生产环境中运行,在代码/Github中跟踪这一点,并使用相同的机制来确保DAG不会意外地运行在错误的环境中,那就更好了。

大约有10个类似的便利函数,如dag_vertica_create_or_replace_view_as()和dag_vertica_train_predict_model()等,我们希望这些dag_xxx()函数的所有调用都可以在生产和开发之间切换(或者到处跳过)。

然而,上一节中的@production和@development装饰器在这里不起作用,因为我们不想将dag_vertica_create_table_as()切换为永远不在其中一个环境中运行。我们希望能够在每次调用时进行设置,并且在我们所有的dag_xxxx()函数中都有这个功能,而不需要复制/粘贴代码。我们想要的是在我们所有的dag_xxxx()函数中添加一个部署参数(有一个好的默认值),这样我们就可以在我们的DAG中添加这个参数,以增加安全性。我们可以通过@deployable装饰器来实现这个目标。

def deployable(func):
    def inner(*args, **kwargs):
        if 'deploy' in kwargs:
            if kwargs['deploy'].lower() in ['production', 'prod'] and gethostname() not in production_servers:
                print('This host is not a production server, skipping...')
                return
            if kwargs['deploy'].lower() in ['development', 'dev'] and gethostname() not in development_servers:
                print('This host is not a development server, skipping...')
                return
            if kwargs['deploy'].lower() in ['skip', 'none']:
                print('Skipping...')
                return
            del kwargs['deploy'] # to avoid func() throwing an unexpected keyword exception
        return func(*args, **kwargs)
    return inner

然后,我们可以将装饰器添加到我们的函数定义中(每个函数添加1行)。

@deployable
def dag_vertica_create_table_as(...): # same signature, nothing changes
    ... # code signature, nothing changes

@deployable
def dag_vertica_create_or_replace_view_as(...): # same signature, nothing changes
    ... # code signature, nothing changes

@deployable
def dag_vertica_train_predict_model(...): # same signature, nothing changes
    ... # code signature, nothing changes

如果我们在这里停止,什么也不会发生,我们不会破坏任何东西。
然而,现在我们可以到我们使用这些函数的DAG文件中,增加1行。

dag = dag_vertica_create_table_as(
    deploy='development', # the function will return None on production
    ...
)

@redirect (stdout)
有时我们写一个大的函数,也会调用其他代码,各种信息都会被打印()出来。或者,我们可能有一个bug,有一堆print(),想在打印出来的内容上加上行号,这样就可以更容易地参考它们。在这些情况下,@redirect可能是有用的。这个装饰器将print()的标准输出重定向到我们自己的逐行打印机,我们可以对它做任何我们想做的事情(包括扔掉它)。

def redirect(func=None, line_print: Callable = None):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            with StringIO() as buf, redirect_stdout(buf):
                func(*args, **kwargs)
                output = buf.getvalue()
            lines = output.splitlines()
            if line_print is not None:
                for line in lines:
                    line_print(line)
            else:
                width = floor(log(len(lines), 10)) + 1
                for i, line in enumerate(lines):
                    i += 1
                    print(f'{i:0{width}}: {line}')
        return inner
    if func is None:
        # decorator was used like @redirect(...)
        return decorator
    else:
        # decorator was used like @redirect, without parens
        return decorator(func)


如果我们使用redirect()而不指定明确的line_print()函数,它就会打印行数,但要加上行号。

@redirect
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(10)

Output:

01: Line #1
02: Line #2
03: Line #3
04: Line #4
05: Line #5
06: Line #6
07: Line #7
08: Line #8
09: Line #9
10: Line #10


如果我们想把所有的打印文本保存到一个变量中,我们也可以实现这一点。

lines = []
def save_lines(line):
    lines.append(line)

@redirect(line_print=save_lines)
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(3)
print(lines)


Output:

['Line #1', 'Line #2', 'Line #3']

重定向stdout的实际工作是由contextlib.redirect_stdout完成的。


@stacktrace
下一个装饰器模式是@stacktrace,当函数被调用和从函数返回值时,它会发出有用的信息。

def stacktrace(func=None, exclude_files=['anaconda']):
    def tracer_func(frame, event, arg):
        co = frame.f_code
        func_name = co.co_name
        caller_filename = frame.f_back.f_code.co_filename
        if func_name == 'write':
            return # ignore write() calls from print statements
        for file in exclude_files:
            if file in caller_filename:
                return # ignore in ipython notebooks
        args = str(tuple([frame.f_locals[arg] for arg in frame.f_code.co_varnames]))
        if args.endswith(',)'):
            args = args[:-2] + ')'
        if event == 'call':
            print(f'--> Executing: {func_name}{args}')
            return tracer_func
        elif event == 'return':
            print(f'--> Returning: {func_name}{args} -> {repr(arg)}')
        return
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            settrace(tracer_func)
            func(*args, **kwargs)
            settrace(None)
        return inner
    if func is None:
        # decorator was used like @stacktrace(...)
        return decorator
    else:
        # decorator was used like @stacktrace, without parens
        return decorator(func)


有了这个,我们就可以装饰我们希望追踪开始的最上面的函数,我们将得到关于分支的有用的输出。

def b():
    print('...')

@stacktrace
def a(arg):
    print(arg)
    b()
    return 'world'

a('foo')
Output:

--> Executing: a('foo')
foo
--> Executing: b()
...
--> Returning: b() -> None
--> Returning: a('foo') -> 'world'

这里唯一的诀窍是。在我的例子中,我在Anaconda上的ipython中运行这段代码,所以我隐藏了代码在路径中有Anaconda的文件中的部分调用栈(否则我在上面的片段中会得到大约50-100个无用的调用栈条目)。这是通过装饰器的exclude_files参数完成的。


@traceclass
与上述类似,我们可以定义一个装饰器@traceclass,与类一起使用,以获得其成员的执行轨迹。这包括在之前的装饰器帖子中,在那里它只是被称为@trace,并且有一个bug(在原来的帖子中已经修复)。这个装饰器。

def traceclass(cls: type):
    def make_traced(cls: type, method_name: str, method: Callable):
        def traced_method(*args, **kwargs):
            print(f'--> Executing: {cls.__name__}::{method_name}()')
            return method(*args, **kwargs)
        return traced_method
    for name in cls.__dict__.keys():
        if callable(getattr(cls, name)) and name != '__class__':
            setattr(cls, name, make_traced(cls, name, getattr(cls, name)))
    return cls


使用:

@traceclass
class Foo:
    i: int = 0
    def __init__(self, i: int = 0):
        self.i = i
    def increment(self):
        self.i += 1
    def __str__(self):
        return f'This is a {self.__class__.__name__} object with i = {self.i}'

f1 = Foo()
f2 = Foo(4)
f1.increment()
print(f1)
print(f2)
Output:

--> Executing: Foo::__init__()
--> Executing: Foo::__init__()
--> Executing: Foo::increment()
--> Executing: Foo::__str__()
This is a Foo object with i = 1
--> Executing: Foo::__str__()
This is a Foo object with i = 4