Triton 入门

Posted by agentd on 03-14,2025

Writen by Gemini

下面是一份详细的入门方案,包含了学习路径、关键概念、实战案例和进阶方向,希望能帮助你快速上手:

一、 学习路径

  1. Python 基础:

    • Triton 是基于 Python 的,确保你对 Python 的语法、数据结构、函数、面向对象编程等有扎实的掌握。
    • 特别关注 NumPy 库,因为 Triton 经常与 NumPy 数组进行交互。
  2. CUDA 基础 (可选但强烈推荐):

    • 虽然 Triton 简化了 GPU 编程,但了解 CUDA 的基本概念(如线程、块、网格、共享内存、全局内存)对理解 Triton 的工作原理和进行性能优化非常有帮助。
    • 可以学习 CUDA C/C++ 的基础知识。
  3. Triton 核心概念:

    • Kernel (内核): 在 GPU 上执行的并行计算单元。
    • Program (程序): 类似于 CUDA 中的线程块 (block)。
    • Grid (网格): 启动内核时指定的程序实例数量。
    • 内存模型: 了解 Triton 中指针、共享内存、全局内存的使用。
    • 内置函数: 熟悉 Triton 提供的用于并行计算、内存访问、原子操作等的内置函数。
  4. Triton 官方文档和教程:

    • 官方文档: https://triton-lang.org/ 这是最权威的学习资料,务必仔细阅读。
    • 教程: 官方文档中包含了多个教程,从基础到高级,一步步引导你学习。
  5. 实战练习:

    • 从简单的算子开始,如向量加法、矩阵乘法等。
    • 逐步尝试更复杂的算子,如卷积、Reduce 操作等。
    • 尝试优化算子的性能,比较不同实现方式的效率。

二、 关键概念详解

  1. Kernel (内核):

    • Triton kernel 是用 @triton.jit 装饰的 Python 函数。
    • 在 kernel 内部,你可以使用 Triton 提供的 API 来编写并行计算逻辑。
    • kernel 函数的参数通常是 NumPy 数组或指向 GPU 内存的指针。
  2. Program (程序):

    • Triton 中的程序类似于 CUDA 中的线程块 (block)。
    • 每个程序实例独立执行 kernel 代码。
    • 程序内的线程可以访问共享内存进行协作。
  3. Grid (网格):

    • 网格定义了启动多少个程序实例来执行 kernel。
    • 通过 triton.cdiv 函数可以方便地计算网格大小。
  4. 内存模型:

    • 指针 (Pointers): Triton 使用指针来表示 GPU 内存地址。
    • 全局内存 (Global Memory): 所有程序实例都可以访问的内存。
    • 共享内存 (Shared Memory): 同一个程序实例内的线程可以共享的内存,访问速度比全局内存快。
    • 寄存器 (Registers): 每个线程私有的内存,访问速度最快。
  5. 内置函数:

    • triton.loadtriton.store:用于从全局内存加载数据和存储数据到全局内存。
    • triton.dot:用于执行矩阵乘法。
    • triton.exptriton.logtriton.sqrt 等:数学函数。
    • triton.atomic_addtriton.atomic_max 等:原子操作。
    • triton.cdiv:向上取整除法。
    • triton.where:条件选择。
    • tl.program_id(axis): 获取当前程序在指定维度上的 ID。
    • tl.num_programs(axis): 获取在指定维度上的程序总数。

三、 实战案例:向量加法

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
                 # NOTE: `constexpr` so it can be used as a shape value
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # for instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
    # In this case, we use a 1D grid where the size is the number of blocks
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # NOTE:
    #  - each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
    #  - don't forget to pass meta-parameters as keywords arguments
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

代码解析:

  1. @triton.jit: 将 add_kernel 函数标记为 Triton kernel。
  2. x_ptr, y_ptr, output_ptr: 指向输入和输出向量的指针。
  3. n_elements: 向量的元素数量。
  4. BLOCK_SIZE: 每个程序处理的元素数量,使用 tl.constexpr 修饰,表示编译时常量。
  5. pid = tl.program_id(axis=0): 获取当前程序在网格中的 ID(因为我们使用一维网格,所以 axis=0)。
  6. block_start = pid * BLOCK_SIZE: 计算当前程序处理的数据块的起始索引。
  7. offsets = block_start + tl.arange(0, BLOCK_SIZE): 生成当前程序处理的元素的偏移量。
  8. mask = offsets < n_elements: 创建一个掩码,用于防止越界访问。
  9. tl.load(x_ptr + offsets, mask=mask): 从全局内存加载数据,使用掩码避免越界。
  10. tl.store(output_ptr + offsets, output, mask=mask): 将结果写回全局内存,使用掩码避免越界。
  11. add(x, y) 函数:
    • 创建输出张量 output
    • 计算网格大小 grid
    • 调用 add_kernel[grid](...) 启动 kernel。

四、 进阶方向

  1. 性能优化:

    • 共享内存: 合理利用共享内存减少全局内存访问。
    • 循环展开 (Loop Unrolling): 减少循环开销。
    • 指令级并行 (Instruction-Level Parallelism): 尽量让计算指令并行执行。
    • 数据布局: 优化数据在内存中的布局,提高访存效率。
    • 使用更合适的block size: 可以通过benchmark来选择最合适的block size。
  2. 更复杂的算子:

    • 卷积 (Convolution): 学习如何在 Triton 中实现卷积操作,可以参考官方教程。
    • Reduce 操作: 实现如 sum, max, min 等 Reduce 操作。
    • 注意力机制 (Attention): 尝试实现 Transformer 模型中的注意力机制。
  3. 与其他框架集成:

    • 学习如何将 Triton 编写的算子集成到 PyTorch 或 TensorFlow 等深度学习框架中。
  4. 参与 Triton 社区:

    • 关注 Triton 的 GitHub 仓库,了解最新的开发动态。
    • 参与 Triton 的论坛或社区讨论,与其他开发者交流经验。

五、 学习建议

  • 动手实践: 学习编程最好的方式就是动手实践,多写代码,多尝试。
  • 阅读源码: 阅读 Triton 的示例代码和官方教程的源码,学习最佳实践。
  • 调试和分析: 学会使用 Triton 提供的调试工具和性能分析工具。
  • 持续学习: GPU 编程和并行计算是一个不断发展的领域,保持学习的热情。

希望这份入门方案能帮助你顺利入门 Triton! 如果你有任何问题,欢迎随时提问。