treequeues: 为pytree对象提供高性能的队列


如果您使用 jax 并且需要在进程之间传递一些 pytree,我可能会为您提供一些东西:)
我开发了一个“树队列”。它是为 pytree 的嵌套数组创建的队列。
传输速度比普通队列快10倍。这是通过利用共享内存阵列和避免pickling数据来完成的。这在开发分布式架构时非常有用,例如,速度是最重要的分布式强化学习。
在我的例子中,这个实现对于在实现 PBT 算法时消除瓶颈非常有用!

这个库包包含了使用pytree和multiprocessing.Arrays在进程之间传输数组和嵌套数组的队列。与vanilla multiprocessing.Queue相比,这个实现可以达到高达10倍的速度,这取决于树的形状和大小以及涉及的进程数量。

通过使用numpy数组与多进程数组的缓冲,数据可以在不需要腌制的情况下发送。
缺点之一是总大小(嵌套数组的大小与队列的最大大小)需要预先分配。

这个包包含TreeQueue和ArrayQueue,在这两种情况下,创建队列时需要传递一个数据实例和最大尺寸。

这个资源库的灵感来自portugueslab的ArrayQueues。

点击标题