代码示例索引
本节汇总教程配套的 4 个可运行 Triton 示例。每个示例都对应教程中的一章重点内容,建议先读章节再跑代码,遇到不懂的地方再回到代码逐行对照。
运行环境
所有示例都需要 NVIDIA / AMD GPU + 已安装 Triton 与 PyTorch。详见安装配置。
示例列表
| # | 示例 | 对应章节 | 难度 | 演示要点 |
|---|---|---|---|---|
| 01 | 向量加法 | 第 4 章 基础算子 | ⭐ | @triton.jit、program_id、mask、benchmark |
| 02 | 融合 Softmax | 第 7 章 算子融合 | ⭐⭐ | 行内 reduce、-inf 填充、数值稳定 softmax |
| 03 | Tiled Matmul + Autotune | 第 5 章 内存优化 / 第 6 章 自动调优 | ⭐⭐⭐ | 二维 tile、tl.dot、grouped ordering、autotune |
| 04 | FlashAttention(简化版) | 第 8 章 FlashAttention 实战 | ⭐⭐⭐⭐ | tiling + online softmax + recomputation |
01 向量加法 (Vector Addition)
Triton 入门的"hello world",对应官方教程 01-vector-add。
演示要点:
- 用
@triton.jit把普通 Python 函数编译为 GPU kernel tl.program_id+tl.arange构造块级偏移tl.load/tl.store的 mask 边界保护- 用
lambda动态计算 grid 尺寸 - 与 PyTorch 对比验证 +
triton.testing带宽基准
bash
python examples/01_vector_add.py02 融合 Softmax (Fused Softmax)
把 PyTorch 朴素 softmax 的 5 个 kernel 融合成 1 个,理论加速比 4×。
演示要点:
- 单 kernel 完成
max → sub → exp → sum → div全流程 BLOCK_SIZE = next_pow2(n_cols)让整行装入 SRAM- 用
-inf作为越界 mask 填充值(不影响 max / exp) tl.max/tl.sum等 reduction- 数值稳定的 softmax(减去 max 再 exp)
限制:要求单行能放进 SRAM(通常 n_cols ≤ 32K)。
bash
python examples/02_fused_softmax.py03 矩阵乘法 (Tiled Matmul with Autotune)
Triton 最经典的实战 kernel——"用 25 行 Python 写出媲美 cuBLAS 的 GEMM"的出处。
演示要点:
- 二维 tiling:每个 program 算输出的
BLOCK_M × BLOCK_N块 tl.dot触发 Tensor Core,fp32 累加器防精度损失@triton.autotune自动搜索最优(BLOCK_*, num_warps, num_stages)- Grouped program ordering 提升 L2 cache 命中(A100 上 +11% TFLOPS)
- 用 stride 参数支持任意 layout(行主序 / 列主序 / 转置)
- 与
torch.matmul(cuBLAS) 对比
autotune 首次开销
首次调用会跑遍所有 config,可能耗时数秒到数十秒;后续调用零开销。
bash
python examples/03_matmul.py04 FlashAttention(简化版)
Triton 高级编程的"集大成"示例,演示 FlashAttention v2 的核心思想。教学版去掉了 causal / warp_specialize / TMA 等高级特性,专注算法本身。
演示要点:
- Tiling:永远不 materialize 完整的
[N, N]attention matrix - Online softmax:用 running
(m_i, l_i)渐进式归一化,新块到来时用α = exp(m_old - m_new)修正旧累加器 - fp32 累加器:
m_i/l_i/acc全 fp32,避免长序列累加误差 - Q 块全程驻留 SRAM,K/V 在内循环滑动(FA2 的关键优化)
- 与朴素 PyTorch attention 对比(
atol=1e-2是 fp16 + 在线累加的合理阈值)
bash
python examples/04_flash_attention.py怎么用这些示例
- 先读对应章节,了解算法和优化要点
- 跑一遍确认能复现教程里的性能数字
- 改一个参数(如
BLOCK_SIZE、num_warps),观察对带宽 / TFLOPS 的影响 - 改完代码做正确性测试:每个示例都自带
test_correctness(),先确保数值对,再看性能
调试新 kernel 的好习惯
正确性优先于性能。每次改完 kernel,先 TRITON_INTERPRET=1 跑通逻辑、再跑 test_correctness()、最后才用 benchmark 看性能。详见第 9 章 最佳实践。