Java中神经网络Triton GPU编程


在本文中,我们将介绍如何使用代码反射在 Java 中实现 Triton 编程模型,以替代 Python。

代码反射(Code Reflection)是 OpenJDK Project Babylon 项目正在研究和开发的一项 Java 平台功能。

什么是Triton
Triton 是一种特定领域编程模型和编译器,开发人员可以用它来编写可编译为 GPU 代码的 Python 程序。

Triton 使那些对 GPU 硬件和 GPU 特定编程语言(如 CUDA)知之甚少或毫无经验的开发人员能够编写出非常高效的并行程序。

Triton 编程模型隐藏了 CUDA 基于线程的编程模型。因此,Triton 编译器能够更好地利用 GPU 硬件,例如,优化可能需要显式同步的情况。

为了实现这种抽象化,开发人员会根据 Triton Python API 进行编程,其中的算术计算是在张量而非标量上进行的。这种张量必须具有恒定的形状、维数和大小(此外,大小必须是 2 的幂次)。

Triton 的发布公告称:
Triton 可以让开发人员以相对较少的工作量达到硬件性能的峰值;例如,它可以用来编写与 cuBLAS 性能相当的 FP16 矩阵乘法内核--这是许多 GPU 程序员无法在 25 行代码内做到的。我们的研究人员已经用它编写出了比同等 Torch 实现效率高达 2 倍的内核,我们很高兴能与社区合作,让每个人都能更方便地使用 GPU 编程。

向量加法
为了解释编程模型,我们将介绍一个简单的例子:向量加法。尽管可以用 CUDA 轻松编写,但这个示例仍具有启发性。

Triton 网站以教程的形式介绍了完整的示例,包括 Triton 如何与 PyTorch 集成。我们将重点讨论 Triton 程序。

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,
               # 每个程序应处理的元素数量。
                 # 注:"constexpr "可用作形状值。
               ):
    有多个
"程序 "在处理不同的数据。我们在这里识别哪个程序
    #:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    该程序将处理从初始数据偏移的输入
    # 例如,如果有一个长度为 256、块大小为 64 的向量,程序
    # 将分别访问元素 [0:64、64:128、128:192、192:256]。
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建掩码,防止内存操作越界访问。
    mask = offsets < n_elements
    DRAM 中加载 x 和 y,如果输入不是数据块大小的
    倍数,则屏蔽掉多余元素。
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)

代码讲解点击标题