3. 核心概念
本章是整个教程的"心脏"。理解了 SPMD、程序实例 (program)、网格 (grid)、块 (block) 这几个概念后,后面所有的算子都只是这些概念在不同场景下的展开。建议读两遍。
3.1 SPMD:Triton 的并行模型
SPMD = Single Program, Multiple Data(单程序,多数据)。
字面意思:同一份程序代码,被并发执行很多份,每一份处理不同的数据。
这跟 CUDA 是一样的——CUDA 也是 SPMD,每个线程执行同一份核函数代码。但 Triton 把这里的"程序"颗粒度做大了。
| 概念 | CUDA | Triton |
|---|---|---|
| 最小并行单位 | 线程 (thread) | 程序实例 (program)(≈ CUDA 的 block) |
| 每个并行单位处理的数据量 | 1 个标量 | 一整块分块 (tile)(向量 / 矩阵) |
| 标识自己的 ID | threadIdx.x | tl.program_id(axis=0) |
| 总并行单位数 | blockDim.x * gridDim.x | tl.num_programs(axis=0) |
心智模型转换
- CUDA:你在写"一个线程的剧本",硬件复制几十万份让线程们去演。
- Triton:你在写"一个程序实例的剧本",硬件复制几千份让程序实例们去演,每个程序实例内部的"并行"由编译器自动展开成 warp 与线程。
3.2 程序实例:Triton 的基本调度单位
每个程序实例 (program) 是一个独立运行的实例,相当于 CUDA 的一个 thread block。
它的特点是:
- 拿到一个唯一 ID(通过
tl.program_id(axis)),决定自己处理哪块数据 - 内部看不到线程——你写的所有 Triton 张量运算(加减乘除、
tl.load/store、tl.dot)都是分块级运算 - 完全独立——不同程序实例之间不能直接通信,要通信只能借助全局内存 + 原子操作
@triton.jit
def my_kernel(...):
pid = tl.program_id(axis=0) # 我是第几个程序实例?
# 根据 pid 计算自己要处理的数据范围
# ... 处理这块数据 ...3.3 网格:怎么决定开多少个程序实例
网格 (grid) 是一个 1D / 2D / 3D 的元组,告诉 Triton 沿每个维度启动多少个程序实例。
grid_1d = (128,) # 1 维,启动 128 个程序实例
grid_2d = (64, 64) # 2 维,启动 64×64 = 4096 个程序实例
grid_3d = (16, 16, 4) # 3 维,启动 16×16×4 = 1024 个程序实例3.3.1 启动核函数的语法
my_kernel[grid](arg1, arg2, ..., BLOCK_SIZE=1024)
# ^^^^ ^^^^^^^^^^^^^^^
# 方括号传 grid constexpr 用关键字传3.3.2 grid 可以是 lambda(推荐)
如果网格大小依赖某个 constexpr(比如 BLOCK_SIZE),写成 lambda 更优雅:
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
my_kernel[grid](x, n_elements, BLOCK_SIZE=1024)meta是 Triton 在启动时自动传入的字典,包含全部 constexpr 参数triton.cdiv(a, b)是向上取整除法,等价于(a + b - 1) // b,专门用于算网格大小
3.3.3 怎么选网格维度?
经验法则:
- 向量 / 1D 数据:1 维网格,沿元素方向切
- 矩阵运算(GEMM、LayerNorm 按行):2 维网格,沿 M、N 方向切
- 批量矩阵运算(batched GEMM、Attention):3 维网格,加一维 batch
3.4 块与分块:每个程序实例一次处理的数据量
块 (block)(在矩阵语境下也叫分块 (tile))是单个程序实例一次处理的数据块的形状。
BLOCK_SIZE: tl.constexpr # 1D 算子,比如 BLOCK_SIZE=1024
BLOCK_M: tl.constexpr # 2D 算子的行方向
BLOCK_N: tl.constexpr # 2D 算子的列方向
BLOCK_K: tl.constexpr # GEMM 的 K 维度必须是 2 的幂
所有 BLOCK_* 大小都必须是 2 的幂(64、128、256、512、1024、2048、4096...)。这是 tl.arange 的硬性要求,也是编译器做向量化的前提。
3.4.1 BLOCK_SIZE 大小怎么选
- 太小(如 32):网格太大,启动开销和调度开销显著
- 太大(如 65536):单个程序实例占用寄存器太多,可能溢出到本地内存(spilling),性能反而下降
- 常用范围:1D 算子
128 ~ 4096;2D 算子(64, 64) ~ (128, 256) - 正确做法:用
@triton.autotune让编译器自动搜索(详见第 6 章)
3.5 program_id 与 num_programs:定位自己
@triton.jit
def kernel(...):
pid = tl.program_id(axis=0) # 我是第几个程序实例(沿轴 0)
n_pids = tl.num_programs(axis=0) # 沿轴 0 总共有多少程序实例
# 多维网格时
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
pid_b = tl.program_id(axis=2) # 比如 batch 维类比 CUDA
| Triton | CUDA |
|---|---|
tl.program_id(axis=0) | blockIdx.x |
tl.program_id(axis=1) | blockIdx.y |
tl.num_programs(axis=0) | gridDim.x |
| —(没有线程概念) | threadIdx.x |
3.6 内存寻址模型:指针 + 偏移
Triton 不使用 arr[i] 这样的下标语法。它使用 指针算术:
# Triton 的数据访问全靠:基址指针 + 偏移向量
offsets = tl.arange(0, BLOCK_SIZE) # 形状 (BLOCK_SIZE,) 的整数向量
ptrs = x_ptr + offsets # 指针 + 整数向量 = 指针向量
x = tl.load(ptrs) # 一次性把整块数据从 HBM 拉进 SRAM3.6.1 一维寻址:向量的情况
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements # 防越界
x = tl.load(x_ptr + offsets, mask=mask)offsets 的取值示意(设 BLOCK_SIZE=4, pid=2):
block_start = 2 * 4 = 8
tl.arange(0, 4) = [0, 1, 2, 3]
offsets = [8, 9, 10, 11]3.6.2 二维寻址:矩阵的情况
二维需要显式提供 stride(步长),把 (row, col) 翻译成线性偏移:
# 取出 A[m_start:m_start+BLOCK_M, n_start:n_start+BLOCK_N] 这一块
offs_m = m_start + tl.arange(0, BLOCK_M) # 形状 (BLOCK_M,)
offs_n = n_start + tl.arange(0, BLOCK_N) # 形状 (BLOCK_N,)
# 借助 None 广播,构造形状 (BLOCK_M, BLOCK_N) 的指针矩阵
ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
# ^^^^^^^^^^^^ ^^^^^^^^^^^^
# (BLOCK_M, 1) (1, BLOCK_N)
a_block = tl.load(ptrs) # 形状 (BLOCK_M, BLOCK_N)stride 是什么
对于一个 PyTorch tensor A,stride_am = A.stride(0)、stride_an = A.stride(1)。
- 行优先 (row-major):
stride_am = N、stride_an = 1 - 列优先 (col-major):
stride_am = 1、stride_an = M
把 stride 作为参数传给核函数,可以让同一段核函数同时支持两种布局。
3.7 Mask:边界处理的"安全带"
实际数据尺寸往往不是 BLOCK_SIZE 的整数倍。比如要处理 1000 个元素,BLOCK_SIZE=128,会启动 8 个程序实例(覆盖 1024 个偏移),但最后一个程序实例的后 24 个偏移会越界。
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
# mask 为 False 的位置:不读 HBM,直接填 other(默认 0)
# 写回时同理:
tl.store(out_ptr + offsets, value, mask=mask)
# mask 为 False 的位置:不写 HBM没有 mask 的后果:越界读取可能拿到垃圾值,越界写入可能破坏其他数据甚至段错误。永远记得加 mask——这是 Triton 编程的肌肉记忆。
3.8 内存层级:写 Triton 必须知道的硬件常识
Triton 自动管理 SRAM,但怎么切块直接决定了你的算子能否高效利用这些层级。
速度快 / 容量小 速度慢 / 容量大
寄存器 (Register) → SRAM (共享内存) → HBM (Global Memory)
<1ns, ~256KB/SM ~10ns, 192~256KB/SM ~400ns, 几十GB
(A100 192 / H100 228 / Blackwell 256)- HBM:所有程序实例都能访问,但慢。
tl.load/tl.store操作的就是这一层。 - SRAM(共享内存):单个程序实例内部加载进来的数据自动放在这里,编译器自动管理,你不需要写
__shared__。 - 寄存器:参与运算时数据驻留在寄存器,也是编译器自动安排。
"写 Triton 就是在调度数据流"
高性能 Triton 算子的核心思路只有一句:
数据从 HBM 拉进 SRAM 后,尽量复用,能不再回 HBM 就不再回。
这就是算子融合 (kernel fusion) 与 FlashAttention 类技巧的本质(详见第 7、8 章)。
3.9 把所有概念串起来:一张总览图
你写的核函数(@triton.jit 装饰的 Python 函数)
↓
启动:kernel[grid](args, BLOCK_SIZE=...)
↓
┌────────────────────────────────────────────────────────┐
│ Triton 运行时按网格启动很多个程序实例 (program) │
│ │
│ program 0 program 1 ... program N-1 │
│ pid=0 pid=1 pid=N-1 │
│ ↓ ↓ ↓ │
│ 各自计算 offsets,从 HBM 载入自己的分块 (tile) │
│ 在 SRAM / 寄存器里做分块级运算 │
│ 写回 HBM │
└────────────────────────────────────────────────────────┘
↓
每个程序实例内部的 warp / 线程调度
全部由 Triton 编译器自动完成3.10 编译器视角:从块级操作到硬件指令
到这里你应该已经能把 Triton 程序"念"出来了。但要真正写得快,必须知道这些块级操作在编译器内部长什么样——它会让你后面看 BLOCK_SIZE、num_warps、num_stages 这些 knob 不再是黑箱。
3.10.1 Triton 的五个编译阶段
@triton.jit 装饰的函数从 Python 源码到 GPU 二进制要经过五道关卡(来源:PyTorch 官方博客 Triton Kernel Compilation Stages):
Python AST ─→ TTIR ─→ TTGIR
(语法树) (块级 IR,无硬件) (块级 IR + 硬件 layout)
│
▼
LLVM IR ─→ PTX ─→ CUBIN/SASS
(标量/向量 IR) (NVPTX 汇编) (GPU 二进制)每一阶段做的事可以一句话概括:
| 阶段 | 输入 | 做什么 |
|---|---|---|
| TTIR | Python AST | 把 tl.load/tl.store/tl.dot 等编织成块级 MLIR,不感知硬件 |
| TTGIR | TTIR | 给每个 tensor 打上 #blocked / #shared / #mma 等layout 编码,跑 coalescing、loop pipelining、layout propagation 等关键 pass |
| LLVM IR | TTGIR | 把块级算子降级成 warp/线程级的标量+向量 LLVM IR |
| PTX | LLVM IR | NVPTX 汇编(人类可读,仍是虚拟 ISA) |
| CUBIN / SASS | PTX | ptxas 编译为某个 SM 架构的真正机器码 |
3.10.2 一段 TTIR 的样子
把第 4 章那个向量加法核函数拿来,TTIR 大致是这样(节选自 PyTorch 官方博客与 kernel.asm['ttir'] 实测):
TTIR 节选(向量加法)
tt.func public @add_kernel(
%arg0: !tt.ptr<f32>, // x_ptr
%arg1: !tt.ptr<f32>, // y_ptr
%arg2: !tt.ptr<f32>, // output_ptr
%arg3: i32 // n_elements
) {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32 // pid
%1 = arith.muli %0, %c1024_i32 : i32 // pid * BLOCK_SIZE
%2 = tt.make_range {start = 0, end = 1024}
: tensor<1024xi32> // tl.arange(0, 1024)
%3 = tt.splat %1 : (i32) -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32> // offsets
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> // mask = offsets < n_elements
%7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> // tl.load(x_ptr + offsets, mask)
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> // tl.load(y_ptr + offsets, mask)
%13 = arith.addf %9, %12 : tensor<1024xf32> // output = x + y
%14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> // tl.store(...)
tt.return
}读 TTIR 时记几个对应关系即可:
tl.program_id(0)→tt.get_program_id xtl.arange(0, N)→tt.make_range- 标量广播到 tensor →
tt.splat - 指针 + 整数向量 →
tt.addptr tl.load/tl.store→tt.load/tt.store- Python 的
+ - * /→arith.addf/arith.mulf/ ...
"块级"是 Triton 的护城河
注意 TTIR 里根本没有 thread / warp 的概念——所有操作都是 tensor<1024xf32> 这种"整块"的。这是 Triton 能自动做 coalescing、向量化、bank conflict 规避的根基:编译器在 TTGIR 阶段才把"整块"拆成 warp 级的 layout,从而拥有完整的自由度去重排访问。
3.10.3 怎么自己 dump 出 IR
JIT 编译后,全部中间产物都挂在 kernel.asm 字典里:
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output = torch.empty_like(x)
# 调用一次触发 JIT
compiled = add_kernel[(1,)](x, y, output, size, BLOCK_SIZE=1024)
print(compiled.asm.keys())
# dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])
print(compiled.asm['ttir']) # 看块级 IR
print(compiled.asm['ttgir']) # 看带 layout 的 IR
print(compiled.asm['ptx']) # 看汇编缓存位置在 ~/.triton/cache/<hash>/,直接 cat 也行。
3.11 内存合并 (Coalescing) 原理
Triton 的"自动 coalescing"听起来像魔法,但只要理解硬件机制就不神秘。
3.11.1 硬件:32 / 128-byte 内存事务
NVIDIA GPU 的全局内存访问以 sector (32 byte) 为最小粒度,L1 cache line 是 128 byte (= 4 sectors)。当一个 warp(32 个线程)执行一条 LD.GLOBAL 时,Load/Store Unit (LSU) 会:
- 收集 32 个线程要读的 32 个地址;
- 算出这些地址覆盖了哪些 32-byte sector;
- 给内存控制器发出"取这些 sector"的请求。
| 访问模式 | sector 数 | 总搬运 | 有效利用率 |
|---|---|---|---|
| 32 thread 连续读 32×4=128 B,对齐到 128 B 边界 | 4 sectors | 128 B | 100% |
| 32 thread 连续读,但起点偏 4 B(非对齐) | 5 sectors | 160 B | 80% |
| 32 thread 跨距读(stride=2,每隔 4 B 跳一个) | 8 sectors | 256 B | 50% |
| 32 thread 完全乱序(最坏情况) | 32 sectors | 1024 B | 12.5% |
数据来自 NVIDIA 官方
- 实测:同一个算子从 coalesced 到 uncoalesced,单 SM 延迟从 ~232 μs 涨到 ~540 μs(>2× 退化)。
- bandwidth-bound 算子上,这个差距可以拉到 8× 以上(来源:NVIDIA Developer Blog How to Access Global Memory Efficiently)。
3.11.2 为什么 Triton 默认就 coalesced
回看 3.10.2 的 TTIR:tt.load %8 接受的是一个 tensor<1024x!tt.ptr<f32>>——整块指针向量。编译器在 TTGIR 阶段会给这个 tensor 打 #blocked layout,按"相邻 thread 拿相邻指针"的方式切分。
举例:BLOCK_SIZE=1024、num_warps=4(默认)→ 4 个 warp × 32 thread = 128 thread,每个 thread 拿 1024/128 = 8 个元素。layout 像这样:
thread 0: elements [0, 1, 2, ..., 7] (连续 8 个)
thread 1: elements [8, 9, 10, ..., 15]
thread 2: elements [16, 17, 18, ..., 23]
...警惕,这不是"thread 0 拿 0、thread 1 拿 1"的"细粒度交错"——而是"thread 0 拿 0..7、thread 1 拿 8..15"的"粗粒度连续"。但同一时刻的 SIMD 指令仍是 thread 0 取 0、thread 1 取 8、thread 2 取 16……这些地址相隔 32 B,覆盖一个连续的 128 B 对齐段,完美 coalesced。
如果你写过 CUDA 应该会想起一句口号:
相邻 thread 访问相邻地址
Triton 把这件事自动做对了:你只需要保证 offsets = pid * BLOCK + tl.arange(0, BLOCK) 这种连续偏移,编译器就会生成 coalesced 的 LD.GLOBAL / LD.GLOBAL.E.128。
3.11.3 什么时候 coalescing 会"失效"
| 失效场景 | 触发条件 | 怎么修 |
|---|---|---|
| 非 2 的幂 stride | 二维 load 用了 stride 17 这种数 | 改用 stride 16 / 32 等 2 幂,或加 padding |
| 指针起点未对齐 16 B | 把张量切片拿到非 0 起点 | .contiguous() 后再传,或在核函数外补齐 |
| 跨距访问 (gather) | offsets = stride * tl.arange(...) 中 stride 远大于 1 | 提前 transpose 让访问连续 |
| 大 dtype + 小 BLOCK | fp64 + BLOCK=32,单 warp 只覆盖 256 B | 增大 BLOCK 至少到 128 |
可以用 Nsight Compute 验证:
ncu --metrics \
l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum.per_second,\
smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct \
python my_kernel.py后者是关键指标:100% 意味着完美 coalesced,<50% 说明在浪费带宽。
3.12 块大小选择的工程权衡
BLOCK_SIZE 是 Triton 第一参数,它一次性决定五件事:occupancy、SRAM 占用、寄存器压力、向量化宽度、尾部处理开销。下面是定量分析。
3.12.1 五个相互拉扯的约束
1. SRAM 占用
对于一个 dtype 为 T 的 tile,单副本 SRAM = BLOCK_SIZE × sizeof(T)。开 num_stages 个流水线副本后再翻 num_stages 倍。matmul 这种要同时驻留 A、B 两个 tile 的,再 ×2。
matmul SRAM ≈ (BLOCK_M × BLOCK_K + BLOCK_K × BLOCK_N) × sizeof(T) × num_stages2. 寄存器压力
每个 thread 上 tile 元素的存储要走寄存器。设单线程持有元素数为 regs_per_thread = BLOCK_SIZE / (num_warps × 32),每元素占 1 个 32-bit 寄存器(fp32 / fp16 都按 32-bit 槽算):
regs/thread (估算) = BLOCK_SIZE × dtype_factor / (num_warps × 32)
+ 累加器、临时变量等额外 ~16~32 个A100/H100 单 SM 寄存器总数 = 65536 (32-bit)。CUDA 硬件上限 = 255 regs/thread。超过约 168 regs/thread 时 occupancy < 25%(详见第 5.x 节)。
3. 向量化宽度
PTX 支持 .b32、.v2.b32、.v4.b32 三档向量宽度(最大 16 B)。BLOCK_SIZE 越大、单线程持有元素越多,编译器越倾向于生成 LD.GLOBAL.E.128 这种 16 B 向量化指令,IO 指令条数最少。
4. Occupancy
每个 SM 同时驻留的 warp 数受限于 SRAM / 寄存器 / block 数。占用率低 → 切换 latency 隐藏能力差。
5. 尾部处理开销
N=1000, BLOCK_SIZE=1024 时,1 个 program 跑 1000 元素,另外 24 个 lane 浪费。BLOCK 越大,"边角料"占比越高;BLOCK 越小,program 数越多,启动开销越显著。
3.12.2 定量预估表 (fp16, num_warps=4)
以一个简单的 elementwise 算子为例(单 tile,无 stages 副本):
| BLOCK_SIZE | SRAM/program | regs/thread (估算) | 推测 occupancy | 向量化指令 |
|---|---|---|---|---|
| 64 | 128 B | ~10 | 100% (受 block 数限) | .b32 (4 B) |
| 128 | 256 B | ~12 | 100% | .v2.b32 (8 B) |
| 256 | 512 B | ~16 | 100% | .v4.b32 (16 B) |
| 512 | 1 KB | ~24 | 100% | .v4.b32 |
| 1024 | 2 KB | ~40 | 75~100% | .v4.b32 |
| 2048 | 4 KB | ~72 | 50% | .v4.b32 |
| 4096 | 8 KB | ~136 | 25% | .v4.b32 |
数据为概念估算,实际值会因 num_warps、累加器精度、临时变量数而浮动 ±30%。精确值请用
kernel.metadata.shared和kernel.n_regs读取。
3.12.3 经验法则
- Elementwise / reduction:
BLOCK_SIZE = 1024 ~ 4096,足以打满 IO 向量化但不至于 spill。 - 2D 算子(LayerNorm/softmax 按行):行
BLOCK_N直接等于 cols(一行一个 program),如果 cols > 4096,拆 row-block。 - Matmul (BLOCK_M, BLOCK_N, BLOCK_K):常见
(128, 128, 32)~(128, 256, 64),A100 上(128, 256, 64) + num_stages=3是 fp16 GEMM 的甜点。 - Attention 类:BLOCK_M=128, BLOCK_N=64 是 FlashAttention-2 的官方推荐起点。
"调一下试试" 不是最佳实践
BLOCK_SIZE 几乎永远应当被纳入 @triton.autotune 的搜索空间。手调一个值只在原型期可接受;上生产之前一定要让 autotuner 跑过你目标 shape 的扫描——5 个候选 × 3 个 num_warps × 2 个 num_stages 不到 1 分钟就能搜完。详见第 6 章。
本章小结
- Triton 沿用 SPMD 模型,但把并行单位从"线程"提升到"程序实例 (program)"——每个程序实例一次处理一整块分块。
- 网格 (grid) 决定开多少程序实例;块大小 (BLOCK_SIZE) 决定每个程序实例一次处理多少数据,必须是 2 的幂。
- 用
tl.program_id(axis)和tl.num_programs(axis)定位与查询网格信息。 - Triton 的数据访问基于 指针 + 偏移,二维寻址要用 stride + 广播。
- 永远加 mask,防止边界越界。
- 内存层级从快到慢是 寄存器 → SRAM (共享内存) → HBM,SRAM 由编译器自动管理,但怎么切块决定了你能否充分复用 SRAM 中的数据。
下一章我们把这些概念落地——用 50 行代码写出第一个能跑、能验证、能 benchmark 的 Triton 核函数。
思考题
- 若数组长度
N = 1000、BLOCK_SIZE = 128,需要启动多少个程序实例?最后一个程序实例实际有效处理多少个元素?mask 在这里起到什么作用? - 处理一个
[1024, 512]的矩阵按行做某种运算(如 LayerNorm),你倾向于用 1D 网格还是 2D 网格?BLOCK_SIZE 怎么选? - 假设你忘了写 mask,对一个
n_elements=1000的向量执行了tl.store(out_ptr + offsets, value)。可能发生什么后果?为什么有时候看起来没事(结果还对)但实际上有 bug?