Skip to content

3. 核心概念

本章是整个教程的"心脏"。理解了 SPMD、程序实例 (program)、网格 (grid)、块 (block) 这几个概念后,后面所有的算子都只是这些概念在不同场景下的展开。建议读两遍。

3.1 SPMD:Triton 的并行模型

SPMD = Single Program, Multiple Data(单程序,多数据)

字面意思:同一份程序代码,被并发执行很多份,每一份处理不同的数据

这跟 CUDA 是一样的——CUDA 也是 SPMD,每个线程执行同一份核函数代码。但 Triton 把这里的"程序"颗粒度做大了。

概念CUDATriton
最小并行单位线程 (thread)程序实例 (program)(≈ CUDA 的 block)
每个并行单位处理的数据量1 个标量一整块分块 (tile)(向量 / 矩阵)
标识自己的 IDthreadIdx.xtl.program_id(axis=0)
总并行单位数blockDim.x * gridDim.xtl.num_programs(axis=0)

心智模型转换

  • CUDA:你在写"一个线程的剧本",硬件复制几十万份让线程们去演。
  • Triton:你在写"一个程序实例的剧本",硬件复制几千份让程序实例们去演,每个程序实例内部的"并行"由编译器自动展开成 warp 与线程。

3.2 程序实例:Triton 的基本调度单位

每个程序实例 (program) 是一个独立运行的实例,相当于 CUDA 的一个 thread block。

它的特点是:

  1. 拿到一个唯一 ID(通过 tl.program_id(axis)),决定自己处理哪块数据
  2. 内部看不到线程——你写的所有 Triton 张量运算(加减乘除、tl.load/storetl.dot)都是分块级运算
  3. 完全独立——不同程序实例之间不能直接通信,要通信只能借助全局内存 + 原子操作
python
@triton.jit
def my_kernel(...):
    pid = tl.program_id(axis=0)   # 我是第几个程序实例?
    # 根据 pid 计算自己要处理的数据范围
    # ... 处理这块数据 ...

3.3 网格:怎么决定开多少个程序实例

网格 (grid) 是一个 1D / 2D / 3D 的元组,告诉 Triton 沿每个维度启动多少个程序实例。

python
grid_1d = (128,)              # 1 维,启动 128 个程序实例
grid_2d = (64, 64)            # 2 维,启动 64×64 = 4096 个程序实例
grid_3d = (16, 16, 4)         # 3 维,启动 16×16×4 = 1024 个程序实例

3.3.1 启动核函数的语法

python
my_kernel[grid](arg1, arg2, ..., BLOCK_SIZE=1024)
#         ^^^^                    ^^^^^^^^^^^^^^^
#         方括号传 grid           constexpr 用关键字传

3.3.2 grid 可以是 lambda(推荐)

如果网格大小依赖某个 constexpr(比如 BLOCK_SIZE),写成 lambda 更优雅:

python
n_elements = x.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
my_kernel[grid](x, n_elements, BLOCK_SIZE=1024)
  • meta 是 Triton 在启动时自动传入的字典,包含全部 constexpr 参数
  • triton.cdiv(a, b) 是向上取整除法,等价于 (a + b - 1) // b,专门用于算网格大小

3.3.3 怎么选网格维度?

经验法则:

  • 向量 / 1D 数据:1 维网格,沿元素方向切
  • 矩阵运算(GEMM、LayerNorm 按行):2 维网格,沿 M、N 方向切
  • 批量矩阵运算(batched GEMM、Attention):3 维网格,加一维 batch

3.4 块与分块:每个程序实例一次处理的数据量

块 (block)(在矩阵语境下也叫分块 (tile))是单个程序实例一次处理的数据块的形状。

python
BLOCK_SIZE: tl.constexpr  # 1D 算子,比如 BLOCK_SIZE=1024
BLOCK_M: tl.constexpr     # 2D 算子的行方向
BLOCK_N: tl.constexpr     # 2D 算子的列方向
BLOCK_K: tl.constexpr     # GEMM 的 K 维度

必须是 2 的幂

所有 BLOCK_* 大小都必须是 2 的幂(64、128、256、512、1024、2048、4096...)。这是 tl.arange 的硬性要求,也是编译器做向量化的前提。

3.4.1 BLOCK_SIZE 大小怎么选

  • 太小(如 32):网格太大,启动开销和调度开销显著
  • 太大(如 65536):单个程序实例占用寄存器太多,可能溢出到本地内存(spilling),性能反而下降
  • 常用范围:1D 算子 128 ~ 4096;2D 算子 (64, 64) ~ (128, 256)
  • 正确做法:用 @triton.autotune 让编译器自动搜索(详见第 6 章)

3.5 program_idnum_programs:定位自己

python
@triton.jit
def kernel(...):
    pid = tl.program_id(axis=0)            # 我是第几个程序实例(沿轴 0)
    n_pids = tl.num_programs(axis=0)       # 沿轴 0 总共有多少程序实例

    # 多维网格时
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    pid_b = tl.program_id(axis=2)          # 比如 batch 维

类比 CUDA

TritonCUDA
tl.program_id(axis=0)blockIdx.x
tl.program_id(axis=1)blockIdx.y
tl.num_programs(axis=0)gridDim.x
—(没有线程概念)threadIdx.x

3.6 内存寻址模型:指针 + 偏移

Triton 不使用 arr[i] 这样的下标语法。它使用 指针算术

python
# Triton 的数据访问全靠:基址指针 + 偏移向量
offsets = tl.arange(0, BLOCK_SIZE)        # 形状 (BLOCK_SIZE,) 的整数向量
ptrs = x_ptr + offsets                    # 指针 + 整数向量 = 指针向量
x = tl.load(ptrs)                         # 一次性把整块数据从 HBM 拉进 SRAM

3.6.1 一维寻址:向量的情况

python
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements              # 防越界
x = tl.load(x_ptr + offsets, mask=mask)

offsets 的取值示意(设 BLOCK_SIZE=4, pid=2):

block_start = 2 * 4 = 8
tl.arange(0, 4) = [0, 1, 2, 3]
offsets = [8, 9, 10, 11]

3.6.2 二维寻址:矩阵的情况

二维需要显式提供 stride(步长),把 (row, col) 翻译成线性偏移:

python
# 取出 A[m_start:m_start+BLOCK_M, n_start:n_start+BLOCK_N] 这一块
offs_m = m_start + tl.arange(0, BLOCK_M)  # 形状 (BLOCK_M,)
offs_n = n_start + tl.arange(0, BLOCK_N)  # 形状 (BLOCK_N,)

# 借助 None 广播,构造形状 (BLOCK_M, BLOCK_N) 的指针矩阵
ptrs = a_ptr + offs_m[:, None] * stride_am + offs_n[None, :] * stride_an
#                ^^^^^^^^^^^^                  ^^^^^^^^^^^^
#                (BLOCK_M, 1)                   (1, BLOCK_N)

a_block = tl.load(ptrs)                   # 形状 (BLOCK_M, BLOCK_N)

stride 是什么

对于一个 PyTorch tensor Astride_am = A.stride(0)stride_an = A.stride(1)

  • 行优先 (row-major):stride_am = Nstride_an = 1
  • 列优先 (col-major):stride_am = 1stride_an = M

把 stride 作为参数传给核函数,可以让同一段核函数同时支持两种布局。

3.7 Mask:边界处理的"安全带"

实际数据尺寸往往不是 BLOCK_SIZE 的整数倍。比如要处理 1000 个元素,BLOCK_SIZE=128,会启动 8 个程序实例(覆盖 1024 个偏移),但最后一个程序实例的后 24 个偏移会越界。

python
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
# mask 为 False 的位置:不读 HBM,直接填 other(默认 0)
# 写回时同理:
tl.store(out_ptr + offsets, value, mask=mask)
# mask 为 False 的位置:不写 HBM

没有 mask 的后果:越界读取可能拿到垃圾值,越界写入可能破坏其他数据甚至段错误。永远记得加 mask——这是 Triton 编程的肌肉记忆。

3.8 内存层级:写 Triton 必须知道的硬件常识

Triton 自动管理 SRAM,但怎么切块直接决定了你的算子能否高效利用这些层级。

        速度快 / 容量小                       速度慢 / 容量大
   寄存器 (Register)    →    SRAM (共享内存)         →    HBM (Global Memory)
   <1ns, ~256KB/SM           ~10ns, 192~256KB/SM          ~400ns, 几十GB
                              (A100 192 / H100 228 / Blackwell 256)
  • HBM:所有程序实例都能访问,但tl.load / tl.store 操作的就是这一层。
  • SRAM(共享内存):单个程序实例内部加载进来的数据自动放在这里,编译器自动管理,你不需要写 __shared__
  • 寄存器:参与运算时数据驻留在寄存器,也是编译器自动安排。

"写 Triton 就是在调度数据流"

高性能 Triton 算子的核心思路只有一句:

数据从 HBM 拉进 SRAM 后,尽量复用,能不再回 HBM 就不再回

这就是算子融合 (kernel fusion) 与 FlashAttention 类技巧的本质(详见第 7、8 章)。

3.9 把所有概念串起来:一张总览图

         你写的核函数(@triton.jit 装饰的 Python 函数)

        启动:kernel[grid](args, BLOCK_SIZE=...)

  ┌────────────────────────────────────────────────────────┐
  │  Triton 运行时按网格启动很多个程序实例 (program)         │
  │                                                        │
  │  program 0       program 1       ...    program N-1   │
  │  pid=0           pid=1                  pid=N-1       │
  │   ↓                ↓                       ↓          │
  │   各自计算 offsets,从 HBM 载入自己的分块 (tile)         │
  │   在 SRAM / 寄存器里做分块级运算                        │
  │   写回 HBM                                             │
  └────────────────────────────────────────────────────────┘

          每个程序实例内部的 warp / 线程调度
              全部由 Triton 编译器自动完成

3.10 编译器视角:从块级操作到硬件指令

到这里你应该已经能把 Triton 程序"念"出来了。但要真正写得快,必须知道这些块级操作在编译器内部长什么样——它会让你后面看 BLOCK_SIZEnum_warpsnum_stages 这些 knob 不再是黑箱。

3.10.1 Triton 的五个编译阶段

@triton.jit 装饰的函数从 Python 源码到 GPU 二进制要经过五道关卡(来源:PyTorch 官方博客 Triton Kernel Compilation Stages):

Python AST          ─→  TTIR              ─→  TTGIR
(语法树)               (块级 IR,无硬件)      (块级 IR + 硬件 layout)


                            LLVM IR  ─→  PTX  ─→  CUBIN/SASS
                            (标量/向量 IR)  (NVPTX 汇编)  (GPU 二进制)

每一阶段做的事可以一句话概括:

阶段输入做什么
TTIRPython ASTtl.load/tl.store/tl.dot 等编织成块级 MLIR,不感知硬件
TTGIRTTIR给每个 tensor 打上 #blocked / #shared / #mmalayout 编码,跑 coalescing、loop pipelining、layout propagation 等关键 pass
LLVM IRTTGIR把块级算子降级成 warp/线程级的标量+向量 LLVM IR
PTXLLVM IRNVPTX 汇编(人类可读,仍是虚拟 ISA)
CUBIN / SASSPTXptxas 编译为某个 SM 架构的真正机器码

3.10.2 一段 TTIR 的样子

把第 4 章那个向量加法核函数拿来,TTIR 大致是这样(节选自 PyTorch 官方博客与 kernel.asm['ttir'] 实测):

TTIR 节选(向量加法)
text
tt.func public @add_kernel(
    %arg0: !tt.ptr<f32>,           // x_ptr
    %arg1: !tt.ptr<f32>,           // y_ptr
    %arg2: !tt.ptr<f32>,           // output_ptr
    %arg3: i32                     // n_elements
) {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32                               // pid
    %1 = arith.muli %0, %c1024_i32 : i32                         // pid * BLOCK_SIZE
    %2 = tt.make_range {start = 0, end = 1024}
            : tensor<1024xi32>                                    // tl.arange(0, 1024)
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>                    // offsets
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>               // mask = offsets < n_elements
    %7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>              // tl.load(x_ptr + offsets, mask)
    %10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>            // tl.load(y_ptr + offsets, mask)
    %13 = arith.addf %9, %12 : tensor<1024xf32>                  // output = x + y
    %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>            // tl.store(...)
    tt.return
}

读 TTIR 时记几个对应关系即可:

  • tl.program_id(0)tt.get_program_id x
  • tl.arange(0, N)tt.make_range
  • 标量广播到 tensor → tt.splat
  • 指针 + 整数向量 → tt.addptr
  • tl.load / tl.storett.load / tt.store
  • Python 的 + - * /arith.addf / arith.mulf / ...

"块级"是 Triton 的护城河

注意 TTIR 里根本没有 thread / warp 的概念——所有操作都是 tensor<1024xf32> 这种"整块"的。这是 Triton 能自动做 coalescing、向量化、bank conflict 规避的根基:编译器在 TTGIR 阶段才把"整块"拆成 warp 级的 layout,从而拥有完整的自由度去重排访问。

3.10.3 怎么自己 dump 出 IR

JIT 编译后,全部中间产物都挂在 kernel.asm 字典里:

python
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output = torch.empty_like(x)

# 调用一次触发 JIT
compiled = add_kernel[(1,)](x, y, output, size, BLOCK_SIZE=1024)
print(compiled.asm.keys())
# dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

print(compiled.asm['ttir'])   # 看块级 IR
print(compiled.asm['ttgir'])  # 看带 layout 的 IR
print(compiled.asm['ptx'])    # 看汇编

缓存位置在 ~/.triton/cache/<hash>/,直接 cat 也行。

3.11 内存合并 (Coalescing) 原理

Triton 的"自动 coalescing"听起来像魔法,但只要理解硬件机制就不神秘。

3.11.1 硬件:32 / 128-byte 内存事务

NVIDIA GPU 的全局内存访问以 sector (32 byte) 为最小粒度,L1 cache line 是 128 byte (= 4 sectors)。当一个 warp(32 个线程)执行一条 LD.GLOBAL 时,Load/Store Unit (LSU) 会:

  1. 收集 32 个线程要读的 32 个地址;
  2. 算出这些地址覆盖了哪些 32-byte sector;
  3. 给内存控制器发出"取这些 sector"的请求。
访问模式sector 数总搬运有效利用率
32 thread 连续读 32×4=128 B,对齐到 128 B 边界4 sectors128 B100%
32 thread 连续读,但起点偏 4 B(非对齐)5 sectors160 B80%
32 thread 跨距读(stride=2,每隔 4 B 跳一个)8 sectors256 B50%
32 thread 完全乱序(最坏情况)32 sectors1024 B12.5%

数据来自 NVIDIA 官方

  • 实测:同一个算子从 coalesced 到 uncoalesced,单 SM 延迟从 ~232 μs 涨到 ~540 μs(>2× 退化)。
  • bandwidth-bound 算子上,这个差距可以拉到 8× 以上(来源:NVIDIA Developer Blog How to Access Global Memory Efficiently)。

3.11.2 为什么 Triton 默认就 coalesced

回看 3.10.2 的 TTIR:tt.load %8 接受的是一个 tensor<1024x!tt.ptr<f32>>——整块指针向量。编译器在 TTGIR 阶段会给这个 tensor 打 #blocked layout,按"相邻 thread 拿相邻指针"的方式切分。

举例:BLOCK_SIZE=1024num_warps=4(默认)→ 4 个 warp × 32 thread = 128 thread,每个 thread 拿 1024/128 = 8 个元素。layout 像这样:

thread 0:  elements [0,   1,   2,   ...,  7]    (连续 8 个)
thread 1:  elements [8,   9,  10,   ...,  15]
thread 2:  elements [16,  17, 18,   ...,  23]
...

警惕,这不是"thread 0 拿 0、thread 1 拿 1"的"细粒度交错"——而是"thread 0 拿 0..7、thread 1 拿 8..15"的"粗粒度连续"。但同一时刻的 SIMD 指令仍是 thread 0 取 0、thread 1 取 8、thread 2 取 16……这些地址相隔 32 B,覆盖一个连续的 128 B 对齐段,完美 coalesced。

如果你写过 CUDA 应该会想起一句口号:

相邻 thread 访问相邻地址

Triton 把这件事自动做对了:你只需要保证 offsets = pid * BLOCK + tl.arange(0, BLOCK) 这种连续偏移,编译器就会生成 coalesced 的 LD.GLOBAL / LD.GLOBAL.E.128

3.11.3 什么时候 coalescing 会"失效"

失效场景触发条件怎么修
非 2 的幂 stride二维 load 用了 stride 17 这种数改用 stride 16 / 32 等 2 幂,或加 padding
指针起点未对齐 16 B把张量切片拿到非 0 起点.contiguous() 后再传,或在核函数外补齐
跨距访问 (gather)offsets = stride * tl.arange(...) 中 stride 远大于 1提前 transpose 让访问连续
大 dtype + 小 BLOCKfp64 + BLOCK=32,单 warp 只覆盖 256 B增大 BLOCK 至少到 128

可以用 Nsight Compute 验证:

bash
ncu --metrics \
    l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum.per_second,\
    smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct \
    python my_kernel.py

后者是关键指标:100% 意味着完美 coalesced,<50% 说明在浪费带宽

3.12 块大小选择的工程权衡

BLOCK_SIZE 是 Triton 第一参数,它一次性决定五件事:occupancy、SRAM 占用、寄存器压力、向量化宽度、尾部处理开销。下面是定量分析。

3.12.1 五个相互拉扯的约束

1. SRAM 占用

对于一个 dtype 为 T 的 tile,单副本 SRAM = BLOCK_SIZE × sizeof(T)。开 num_stages 个流水线副本后再翻 num_stages 倍。matmul 这种要同时驻留 A、B 两个 tile 的,再 ×2。

matmul SRAM ≈ (BLOCK_M × BLOCK_K + BLOCK_K × BLOCK_N) × sizeof(T) × num_stages

2. 寄存器压力

每个 thread 上 tile 元素的存储要走寄存器。设单线程持有元素数为 regs_per_thread = BLOCK_SIZE / (num_warps × 32),每元素占 1 个 32-bit 寄存器(fp32 / fp16 都按 32-bit 槽算):

regs/thread (估算) = BLOCK_SIZE × dtype_factor / (num_warps × 32)
                     + 累加器、临时变量等额外 ~16~32 个

A100/H100 单 SM 寄存器总数 = 65536 (32-bit)。CUDA 硬件上限 = 255 regs/thread。超过约 168 regs/thread 时 occupancy < 25%(详见第 5.x 节)。

3. 向量化宽度

PTX 支持 .b32.v2.b32.v4.b32 三档向量宽度(最大 16 B)。BLOCK_SIZE 越大、单线程持有元素越多,编译器越倾向于生成 LD.GLOBAL.E.128 这种 16 B 向量化指令,IO 指令条数最少。

4. Occupancy

每个 SM 同时驻留的 warp 数受限于 SRAM / 寄存器 / block 数。占用率低 → 切换 latency 隐藏能力差。

5. 尾部处理开销

N=1000, BLOCK_SIZE=1024 时,1 个 program 跑 1000 元素,另外 24 个 lane 浪费。BLOCK 越大,"边角料"占比越高;BLOCK 越小,program 数越多,启动开销越显著。

3.12.2 定量预估表 (fp16, num_warps=4)

以一个简单的 elementwise 算子为例(单 tile,无 stages 副本):

BLOCK_SIZESRAM/programregs/thread (估算)推测 occupancy向量化指令
64128 B~10100% (受 block 数限).b32 (4 B)
128256 B~12100%.v2.b32 (8 B)
256512 B~16100%.v4.b32 (16 B)
5121 KB~24100%.v4.b32
10242 KB~4075~100%.v4.b32
20484 KB~7250%.v4.b32
40968 KB~13625%.v4.b32

数据为概念估算,实际值会因 num_warps、累加器精度、临时变量数而浮动 ±30%。精确值请用 kernel.metadata.sharedkernel.n_regs 读取。

3.12.3 经验法则

  • Elementwise / reductionBLOCK_SIZE = 1024 ~ 4096,足以打满 IO 向量化但不至于 spill。
  • 2D 算子(LayerNorm/softmax 按行):行 BLOCK_N 直接等于 cols(一行一个 program),如果 cols > 4096,拆 row-block。
  • Matmul (BLOCK_M, BLOCK_N, BLOCK_K):常见 (128, 128, 32) ~ (128, 256, 64),A100 上 (128, 256, 64) + num_stages=3 是 fp16 GEMM 的甜点。
  • Attention 类:BLOCK_M=128, BLOCK_N=64 是 FlashAttention-2 的官方推荐起点。

"调一下试试" 不是最佳实践

BLOCK_SIZE 几乎永远应当被纳入 @triton.autotune 的搜索空间。手调一个值只在原型期可接受;上生产之前一定要让 autotuner 跑过你目标 shape 的扫描——5 个候选 × 3 个 num_warps × 2 个 num_stages 不到 1 分钟就能搜完。详见第 6 章。

本章小结

  • Triton 沿用 SPMD 模型,但把并行单位从"线程"提升到"程序实例 (program)"——每个程序实例一次处理一整块分块。
  • 网格 (grid) 决定开多少程序实例;块大小 (BLOCK_SIZE) 决定每个程序实例一次处理多少数据,必须是 2 的幂
  • tl.program_id(axis)tl.num_programs(axis) 定位与查询网格信息。
  • Triton 的数据访问基于 指针 + 偏移,二维寻址要用 stride + 广播
  • 永远加 mask,防止边界越界。
  • 内存层级从快到慢是 寄存器 → SRAM (共享内存) → HBM,SRAM 由编译器自动管理,但怎么切块决定了你能否充分复用 SRAM 中的数据。

下一章我们把这些概念落地——用 50 行代码写出第一个能跑、能验证、能 benchmark 的 Triton 核函数。

思考题

  1. 若数组长度 N = 1000BLOCK_SIZE = 128,需要启动多少个程序实例?最后一个程序实例实际有效处理多少个元素?mask 在这里起到什么作用?
  2. 处理一个 [1024, 512] 的矩阵按行做某种运算(如 LayerNorm),你倾向于用 1D 网格还是 2D 网格?BLOCK_SIZE 怎么选?
  3. 假设你忘了写 mask,对一个 n_elements=1000 的向量执行了 tl.store(out_ptr + offsets, value)。可能发生什么后果?为什么有时候看起来没事(结果还对)但实际上有 bug

基于 MIT 协议发布