Skip to content

代码示例索引

本节汇总教程配套的 4 个可运行 Triton 示例。每个示例都对应教程中的一章重点内容,建议先读章节再跑代码,遇到不懂的地方再回到代码逐行对照。

运行环境

所有示例都需要 NVIDIA / AMD GPU + 已安装 Triton 与 PyTorch。详见安装配置

示例列表

#示例对应章节难度演示要点
01向量加法第 4 章 基础算子@triton.jitprogram_id、mask、benchmark
02融合 Softmax第 7 章 算子融合⭐⭐行内 reduce、-inf 填充、数值稳定 softmax
03Tiled Matmul + Autotune第 5 章 内存优化 / 第 6 章 自动调优⭐⭐⭐二维 tile、tl.dot、grouped ordering、autotune
04FlashAttention(简化版)第 8 章 FlashAttention 实战⭐⭐⭐⭐tiling + online softmax + recomputation

01 向量加法 (Vector Addition)

查看源码

Triton 入门的"hello world",对应官方教程 01-vector-add

演示要点

  • @triton.jit 把普通 Python 函数编译为 GPU kernel
  • tl.program_id + tl.arange 构造块级偏移
  • tl.load / tl.store 的 mask 边界保护
  • lambda 动态计算 grid 尺寸
  • 与 PyTorch 对比验证 + triton.testing 带宽基准
bash
python examples/01_vector_add.py

02 融合 Softmax (Fused Softmax)

查看源码

把 PyTorch 朴素 softmax 的 5 个 kernel 融合成 1 个,理论加速比 4×。

演示要点

  • 单 kernel 完成 max → sub → exp → sum → div 全流程
  • BLOCK_SIZE = next_pow2(n_cols) 让整行装入 SRAM
  • -inf 作为越界 mask 填充值(不影响 max / exp)
  • tl.max / tl.sum 等 reduction
  • 数值稳定的 softmax(减去 max 再 exp)

限制:要求单行能放进 SRAM(通常 n_cols ≤ 32K)。

bash
python examples/02_fused_softmax.py

03 矩阵乘法 (Tiled Matmul with Autotune)

查看源码

Triton 最经典的实战 kernel——"用 25 行 Python 写出媲美 cuBLAS 的 GEMM"的出处。

演示要点

  • 二维 tiling:每个 program 算输出的 BLOCK_M × BLOCK_N
  • tl.dot 触发 Tensor Core,fp32 累加器防精度损失
  • @triton.autotune 自动搜索最优 (BLOCK_*, num_warps, num_stages)
  • Grouped program ordering 提升 L2 cache 命中(A100 上 +11% TFLOPS)
  • 用 stride 参数支持任意 layout(行主序 / 列主序 / 转置)
  • torch.matmul (cuBLAS) 对比

autotune 首次开销

首次调用会跑遍所有 config,可能耗时数秒到数十秒;后续调用零开销。

bash
python examples/03_matmul.py

04 FlashAttention(简化版)

查看源码

Triton 高级编程的"集大成"示例,演示 FlashAttention v2 的核心思想。教学版去掉了 causal / warp_specialize / TMA 等高级特性,专注算法本身。

演示要点

  • Tiling:永远不 materialize 完整的 [N, N] attention matrix
  • Online softmax:用 running (m_i, l_i) 渐进式归一化,新块到来时用 α = exp(m_old - m_new) 修正旧累加器
  • fp32 累加器m_i / l_i / acc 全 fp32,避免长序列累加误差
  • Q 块全程驻留 SRAM,K/V 在内循环滑动(FA2 的关键优化)
  • 与朴素 PyTorch attention 对比(atol=1e-2 是 fp16 + 在线累加的合理阈值)
bash
python examples/04_flash_attention.py

怎么用这些示例

  1. 先读对应章节,了解算法和优化要点
  2. 跑一遍确认能复现教程里的性能数字
  3. 改一个参数(如 BLOCK_SIZEnum_warps),观察对带宽 / TFLOPS 的影响
  4. 改完代码做正确性测试:每个示例都自带 test_correctness(),先确保数值对,再看性能

调试新 kernel 的好习惯

正确性优先于性能。每次改完 kernel,先 TRITON_INTERPRET=1 跑通逻辑、再跑 test_correctness()、最后才用 benchmark 看性能。详见第 9 章 最佳实践

基于 MIT 协议发布