1. Triton 简介
本章带你认识 Triton:它是什么、为什么需要它、与 CUDA 的本质区别在哪、又适合哪些场景。读完本章你应该能向同事用三分钟讲清楚"Triton 解决了什么问题"。
1.1 Triton 是什么
Triton 是 OpenAI 主导开源的、嵌入 Python 的领域特定语言 (DSL) 与编译器,用来编写高性能 GPU 核函数 (kernel)。
- 起源:由 Philippe Tillet 在哈佛读博期间提出(MAPL 2019 论文),加入 OpenAI 后于 2021 年 7 月发布 Triton 1.0。
- 目标:让没有 CUDA 经验的研究者,也能写出"大多数时候达到专家水准"的 GPU 代码。OpenAI 官方说过——用 25 行 Python 就能写出可媲美 cuBLAS 的 FP16 矩阵乘法。
- 编译链路:你的 Python 源码 → Triton MLIR / Triton-IR → LLVM IR → 目标设备代码(NVIDIA 输出 PTX/CUBIN,AMD 输出 AMDGCN)。Triton 直接绕过了 cuBLAS、cuDNN 等闭源库,更倾向于使用 cutlass 等开源原语。
- 生态地位:PyTorch 2.x 的
torch.compile(Inductor 后端)大量依赖 Triton 生成融合算子;FlashAttention、Unsloth、Liger Kernel、vLLM 等知名项目都基于 Triton 实现关键算子。
同名警告
本教程讲的 Triton 编程语言(triton-lang/triton,OpenAI 主导)与 NVIDIA 的 Triton Inference Server(模型推理服务框架)只是同名,二者没有任何关系。本教程之后出现的 "Triton" 都指前者。
1.2 为什么需要 Triton
要回答这个问题,得先看看不用 Triton 时的两种典型路径:
| 路径 | 优点 | 痛点 |
|---|---|---|
| 手写 CUDA C++ | 性能天花板高、生态成熟 | 学习曲线陡峭;要操心 warp、共享内存银行冲突、warp 同步、寄存器压力;调试痛苦;改算法等于重写 |
| 完全依赖 PyTorch / TensorFlow 原生算子 | 用着舒服 | 一旦你的算法不是这些库的"标准操作"(例如自定义 Attention、新颖的归一化),就会发生大量算子分解与内存往返,性能大打折扣 |
Triton 试图取中间一条路:
- 像 NumPy 一样写代码(Python 语法 + 张量级运算),但生成的是直达 PTX 的高性能核函数
- 屏蔽线程级细节:你不再操心
threadIdx、共享内存、warp 同步 - 保留关键控制权:你仍能决定数据怎么分块 (tile)、网格 (grid) 怎么切、内存怎么访问
一句话:Triton 把 GPU 编程的抽象层级从"线程"上抬到了"块",绝大多数性能优化由编译器自动完成,你只需要思考"算法该怎么分块"。
1.3 Triton vs CUDA:本质差异
1.3.1 编程范式对比
| 维度 | CUDA | Triton |
|---|---|---|
| 范式 | 标量程序 + 块化线程 (SIMT) | 块化程序 + 标量线程 (Blocked Program) |
| 你写的是什么 | 单个线程的逻辑(用 threadIdx 区分) | 一个程序实例 (program) 处理一整块分块 (tile) 的逻辑 |
| 内存合并 (coalescing) | 手动设计访问模式 | 自动 |
| 共享内存 | 手动 __shared__ 声明 + 同步 | 自动分配与同步 |
| SM 内部调度(warp 调度、寄存器分配) | 手动 | 自动 |
| 跨 SM 调度(grid 切分) | 手动 | 手动(仍由你决定) |
| Tensor Core 使用 | 手动 wmma / wgmma | 自动指令选择 |
表格的关键信息只有一行:Triton 把"线程级"的所有麻烦事自动化了,只保留"块级"的决策让你做。
1.3.2 用矩阵乘法直观对比
CUDA 的思维:为输出矩阵的每个标量元素开一个线程,线程内部做 K 维度的累加循环。
// CUDA 心智模型(伪代码)
#pragma parallel for(int m = 0; m < M; m++)
#pragma parallel for(int n = 0; n < N; n++) {
float acc = 0;
for (int k = 0; k < K; k++)
acc += A[m, k] * B[k, n];
C[m, n] = acc;
}Triton 的思维:每个程序实例处理一个 MB × NB 的输出分块,内部直接做分块级的 @ 运算。
// Triton 心智模型(伪代码)
#pragma parallel for(int m = 0; m < M; m += MB)
#pragma parallel for(int n = 0; n < N; n += NB) {
float acc[MB, NB] = 0;
for (int k = 0; k < K; k += KB)
acc += A[m:m+MB, k:k+KB] @ B[k:k+KB, n:n+NB];
C[m:m+MB, n:n+NB] = acc;
}注意第二段里没有线程的概念——acc 是一整块 MB × NB 的矩阵,A[..] @ B[..] 是一次分块级矩阵乘。这正是 Triton 的编程模型:你看见的永远是"块",编译器把这些块运算自动展开成 warp 和线程上的 PTX 指令。
怎么记忆
- CUDA:你是"线程视角",要想清楚 这个线程 做什么。
- Triton:你是"程序实例视角",要想清楚 这个程序实例 处理 哪一块 数据。
1.4 Triton 的优势与局限
优势
- 生产力:Python-like 语法,无需 C++ 工具链;新算法从想法到落地的时间通常缩短一个数量级。
- 可读性:算子代码就是算法本身,没有
__syncthreads()之类的噪音。 - 自动优化:自动内存合并、thread swizzling、向量化、共享内存分配、流水线、Tensor Core 指令选择。
- JIT + 缓存:装饰器即编译,首次调用 JIT,后续从缓存读取,迭代快。
- 跨厂商可移植:同一份 Triton 代码可在 NVIDIA / AMD(ROCm)/ 部分国产卡上运行(依赖后端实现成熟度)。
局限
也别神化 Triton
- 在 Hopper / Blackwell 等最新硬件的极致特性(FP8、TMA、wgmma 异步)上,Triton 仍可能比手写 CUDA 慢约 20%(来源:Modular 博客)。
- 不支持 inter-SM 同步——跨程序实例通信仍需走原子操作或多个核函数 launch。
- 调试体验弱于 CUDA:虽然有
TRITON_INTERPRET=1解释模式,但毕竟不如 cuda-gdb / Nsight 强大。 - Windows 主线不支持:必须用社区维护的
triton-windowsfork 或 WSL2。 - macOS 完全不支持:Mac 用户只能远程使用 Linux GPU。
1.5 Triton 在生态中的位置
可以用一张"抽象层级"表来摆放:
| 层级 | 代表 | 你写什么 |
|---|---|---|
| 框架 API | PyTorch nn.Linear、F.softmax | 调用现成算子 |
| 算子 DSL(Triton 在这一层) | Triton、CUTLASS(C++ 模板) | 写新算子或融合算子 |
| 底层 CUDA | CUDA C++ + PTX | 极致手工优化 |
Triton 既不是"另一个深度学习框架"(它不替代 PyTorch),也不是"CUDA 的替代品"(PTX 仍是最终目标),它是 PyTorch 与 PTX 之间的中间层,专门解决"写新算子或融合算子"的场景。
1.6 适合哪些场景
适合:
- 自定义 Attention 变体(FlashAttention、PagedAttention、ALiBi、自定义 mask)
- 算子融合(把多个小算子合并成一个,省去中间张量的 HBM 往返)
- 量化算子(INT8/INT4 GEMM、FP8 LayerNorm)
- 稀疏算子(block-sparse、structured sparse)
- 大模型推理优化(KV cache 操作、ROPE、RMSNorm)
- 研究新算法(论文里出现的新型归一化、新型激活函数)
不太适合:
- 简单逐元素操作(PyTorch 原生
+已经够快,Triton 并不会更快) - 极度依赖最新硬件特性的极致优化(Hopper TMA + wgmma 的最后 20% 性能)
- 不熟悉 GPU 内存层级的纯新手(理解 HBM/SRAM/寄存器 是写好 Triton 的前提)
本章小结
- Triton 是 OpenAI 开源的 Python 嵌入式 GPU DSL/编译器,目标是让算子开发不再是少数 CUDA 专家的特权。
- 它通过把抽象层级从 线程 提升到 块 (block / tile),自动化了内存合并、共享内存、warp 调度等绝大多数底层细节。
- 与 CUDA 的核心差异是 "块化程序 + 标量线程" vs "标量程序 + 块化线程"。
- 它的生态定位是 PyTorch 框架算子 与 CUDA C++ 之间的中间层,PyTorch 2.x Inductor、FlashAttention、vLLM 等关键基础设施都在用它。
- 适合自定义算子、融合、量化、大模型推理优化等场景;不适合简单逐元素操作或极致硬件特性优化。
全教程术语约定
本教程统一采用以下中文术语:核函数 (kernel)、块 (block)、网格 (grid)、分块 (tile)、共享内存 (shared memory)、程序实例 (program);warp 不翻译。 代码标识符(如 add_kernel、BLOCK_SIZE、tl.program_id、grid = lambda meta: ...)保留英文原貌。
下一章先把环境装好,让你能亲手跑出第一段 Triton 代码——纸上得来终觉浅。
思考题
- 同一个矩阵乘法核函数,用 CUDA 写需要操心
__shared__、__syncthreads()、bank conflict;用 Triton 写完全不用。这些工作"消失"了吗?还是被谁接管了? - 如果你只是想把
y = x.relu() + b这样一个简单表达式加速,应该考虑 Triton,还是直接用torch.compile?为什么?(提示:torch.compile的 Inductor 后端在底层用的是什么?) - 假设有一篇新论文提出了一种新的 Attention 变体,常规的
torch.nn.functional.scaled_dot_product_attention不支持它的 mask 模式。从开发效率和性能两个角度评估,应该选 (a) 自己写 CUDA C++ 核函数,(b) 用 PyTorch 算子拼出来,(c) 用 Triton 写一个融合核函数。说出你的理由。