Skip to content

10. Triton 编译器原理

本章是从"写 kernel"到"调 kernel"的思维跃迁。前面九章告诉你怎么写 Triton;从本章起,我们打开编译器的黑箱——理解每一个块级操作如何被降级到硬件指令、每一个 pass 在做什么决定。掌握这些,你才能在性能差 20% 的两个 config 之间说出"为什么",而不是只能"再 autotune 试试"。

本章内容概览

  • 10.1 为什么要理解编译器
  • 10.2 编译流水线总览
  • 10.3 TTIR:块级语义表示
  • 10.4 TTGIR:硬件 layout 与关键优化 pass
  • 10.5 LLVM IR → PTX → CUBIN
  • 10.6 如何 dump 各阶段 IR
  • 10.7 实例:向量加法的完整 IR 变换
  • 10.8 编译缓存:~/.triton/cache
  • 10.9 Triton 与 MLIR 的关系
  • 10.10 本章小结

10.1 为什么要理解编译器

如果你只想写出"能跑"的 kernel,前九章已经够用。但当你遇到下面这些场景,没有编译器视角就只能瞎调:

场景不懂编译器懂编译器
autotune 选出的最快 config 寄存器用了 192,occupancy 25%困惑:理论说 occupancy 越高越好知道 K 维大时 ILP 收益高于 occupancy,25% 是合理选择
num_stages=4num_stages=3 慢 10%试了 num_stages=2 还是慢直接查 SMEM 占用,看是否触发降级
改了一行循环顺序,性能翻倍神秘玄学知道 tritongpu-pipeline pass 只识别特定循环结构
同样的 kernel 在 A100 用 mma.sync、在 H100 用 wgmma不知道为什么知道 tritongpu-accelerate-matmul 按 SM 能力选择
tl.dot(a, b) 比手写 tl.sum(a[:, :, None] * b[None, :, :], axis=1) 快 100×以为是 API 差异知道前者走 tensor core,后者走 CUDA core

核心收益:编译器视角让你从"猜 + 试"升级到"看 IR + 推理",调优效率会有数量级提升。

本章不要求你成为 LLVM 工程师

读完后你应该能做到:

  1. 看一段 TTIR / TTGIR 知道它在做什么
  2. 知道每个 pass 解决什么问题、什么时候会失败
  3. kernel.asm['ptx'] 验证编译器是否生成了你期望的指令
  4. 出现奇怪性能时知道用哪个环境变量定位

10.2 编译流水线总览

@triton.jit 装饰的 Python 函数从源码到 GPU 二进制要经过 六个阶段,每一级是独立的 MLIR / LLVM 模块转换:

@triton.jit (Python AST)

    ▼  walk AST + tl.* → MLIR ops
TTIR  (Triton dialect, 机器无关)            ←─ asm['ttir']

    ▼  ConvertTritonToTritonGPU + 布局推断
TTGIR (TritonGPU + TritonNvidiaGPU dialect)  ←─ asm['ttgir']
    │                带 #blocked / #shared / #mma 等 layout encoding
    ▼  TritonGPU → LLVM dialect → LLVM IR
LLIR  (NVVM 注解 + nvvm intrinsic)           ←─ asm['llir']

    ▼  LLVM NVPTX backend
PTX   (虚拟 ISA, 文本)                       ←─ asm['ptx']

    ▼  ptxas (NVIDIA 工具链)
CUBIN (SM-specific 机器码)                   ←─ asm['cubin']

每一阶段的入口在 python/triton/backends/nvidia/compiler.pyadd_stages 函数里,Pass 在 C++ 侧注册于 python/src/passes.cc

不同 backend 的末两级不同:

Backend末段工具
NVIDIAPTX → CUBINptxas
AMDLLVM IR → AMDGCN → HSACOLLVM AMDGPU backend + lld
IntelLLVM IR → GEN dialect → SPIR-V → 机器码IGC

10.2.1 为什么要分这么多层

每一层解决不同抽象层级的问题:

  • TTIR:只关心"块级算子"——tt.loadtt.dottt.store。不知道 GPU 是 NVIDIA 还是 AMD,不知道有多少 SM。这一层做的优化(CSE、LICM、常量折叠)所有 backend 共享。
  • TTGIR:知道 GPU 类型和 SM 能力。这一层做 coalescing、software pipelining、tensor core lowering——Triton 90% 的性能秘密都在这层
  • LLVM IR:交给 LLVM 做寄存器分配、指令选择、循环展开——这些通用编译器技术成熟稳定。
  • PTX → CUBIN:NVIDIA ptxas 的事情,Triton 不操心,但 spill / register 分配的最终决策在这里。

Triton 自己不做寄存器分配

看到 TTGIR 时所有 tensor 都是"虚拟的"——没有"这个值放在 %r17 还是 %r42"这种细节。真正的寄存器分配发生在 LLVM 和 ptxas 阶段。这也是为什么有时候同一个 Triton 源码、相同 config,换个 ptxas 版本性能会变 5%。


10.3 TTIR:块级语义表示

TTIR (Triton IR) 是 Triton 编译流水线的第一站。它由 @triton.jit 装饰器解析 Python AST 时构造,输出一个 机器无关 的 MLIR 模块。

10.3.1 TTIR 的核心特征

  1. 块级 (tile-level):所有 tensor 都是 tensor<N×N×fp16> 这种整块。没有任何 thread/warp 概念
  2. 类型已确定:每个 tt.constexpr 已经被求值,每个 tensor 的 dtype 和 shape 都是静态的。
  3. 指针类型!tt.ptr<f32> 表达指针,区别于普通 LLVM 指针——它带 address space 和对齐信息。
  4. MLIR-native:使用上游 arithmathscf dialect 处理标量算术、循环、控制流。

10.3.2 常见 op 与 Python 对应

PythonTTIR op含义
tl.program_id(0)tt.get_program_id x拿到 program ID
tl.arange(0, N)tt.make_range {start=0, end=N}构造 [0, 1, ..., N-1]
x + scalar(广播)tt.splat + arith.addi标量广播到 tensor
ptr + offsetstt.addptr指针 + 整数向量
tl.load(ptr, mask)tt.load %ptr, %mask块级加载
tl.store(ptr, v, mask)tt.store %ptr, %v, %mask块级存储
tl.dot(a, b, c)tt.dot %a, %b, %c矩阵乘累加(关键 op,触发 tensor core)
tl.sum(x, axis=0)tt.reduce + regionreduce
x.to(tl.float16)arith.truncf / arith.extf类型转换
tl.where(cond, a, b)arith.select三元选择

10.3.3 TTIR 阶段的关键 pass

Pass作用
inliner把子函数 inline(Triton 没有真正的运行时函数边界)
triton-combine模式合并,如 select(cond, load(ptr, broadcast(cond), other), other) → 一条 masked load
triton-rewrite-tensor-pointer重写 block pointer API (tt.make_tensor_ptr)
triton-reorder-broadcast把 broadcast 推到 elementwise 之后,减少冗余计算
canonicalizeMLIR 标准规范化(常量折叠入口)
cseCommon Subexpression Elimination
licmLoop-Invariant Code Motion
symbol-dceDead Symbol 消除

类型推断在哪里

TTIR 没有独立的"类型推断 pass"——所有类型在 frontend(Python → TTIR 构造时)就已经确定。Triton frontend 维护一个简化的类型系统,tl.constexpr 在这里被静态求值。这就是为什么不同 constexpr 值会触发不同的编译产物:每组值都对应一份独立的 TTIR。

10.3.4 一个 minimal TTIR 示例

向量加法源码:

python
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    off = pid * BLOCK + tl.arange(0, BLOCK)
    mask = off < n_elements
    x = tl.load(x_ptr + off, mask=mask)
    y = tl.load(y_ptr + off, mask=mask)
    tl.store(out_ptr + off, x + y, mask=mask)

对应 TTIR(节选,BLOCK=1024):

完整 TTIR
text
tt.func public @add_kernel(%x: !tt.ptr<f32>, %y: !tt.ptr<f32>,
                            %out: !tt.ptr<f32>, %n: i32) {
  %c1024 = arith.constant 1024 : i32
  %pid = tt.get_program_id x : i32
  %base = arith.muli %pid, %c1024 : i32                          // pid * BLOCK
  %range = tt.make_range {start=0:i32, end=1024:i32}             // [0..1023]
                         : tensor<1024xi32>
  %base_v = tt.splat %base : i32 -> tensor<1024xi32>
  %off = arith.addi %base_v, %range : tensor<1024xi32>           // offsets
  %n_v = tt.splat %n : i32 -> tensor<1024xi32>
  %mask = arith.cmpi slt, %off, %n_v : tensor<1024xi32>          // mask
  %x_p = tt.splat %x : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
  %xp = tt.addptr %x_p, %off : tensor<1024x!tt.ptr<f32>>,
                                tensor<1024xi32>
  %xv = tt.load %xp, %mask : tensor<1024x!tt.ptr<f32>>           // tl.load(x)
  %y_p = tt.splat %y : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
  %yp = tt.addptr %y_p, %off : ...
  %yv = tt.load %yp, %mask : tensor<1024x!tt.ptr<f32>>           // tl.load(y)
  %sum = arith.addf %xv, %yv : tensor<1024xf32>                  // x + y
  %o_p = tt.splat %out : ...
  %op = tt.addptr %o_p, %off : ...
  tt.store %op, %sum, %mask : tensor<1024x!tt.ptr<f32>>          // tl.store
  tt.return
}

注意 TTIR 里没有任何 #blocked 之类的 layout 标记——tensor 类型只是 tensor<1024xf32>。layout 是下一阶段 (TTGIR) 才添加的。


10.4 TTGIR:硬件 layout 与关键优化 pass

TTGIR (TritonGPU IR) 是 Triton 编译器的真正核心。它在 TTIR 的基础上:

  1. 给每个 tensor 打 layout encoding:告诉硬件元素如何映射到 thread、warp、shared memory。
  2. 跑大量优化 pass:coalescing、software pipelining、tensor core lowering、warp specialization 等。
  3. 引入硬件特定 optriton_gpu.async_copy_global_to_local(对应 cp.async)、triton_nvidia_gpu.tcgen05.mma(Blackwell tensor core)等。

10.4.1 Layout 编码:三种核心 layout

#blocked — 通用分块布局

text
#blocked = #triton_gpu.blocked<{
    sizePerThread = [8],      // 每个 thread 拿 8 个连续元素
    threadsPerWarp = [32],    // 每个 warp 32 个 thread
    warpsPerCTA = [4],        // 每个 program 4 个 warp
    order = [0]               // 沿 axis 0 (最快变维)
}>

含义:一个 1024 元素的 tensor 被切成 4 warp × 32 thread × 8 元素 = 1024,每个 thread 拿连续 8 个 fp32 = 32 字节,对应 ld.global.v4.b32 两条指令(4-wide vectorized load)。

#shared — 共享内存布局

text
#shared = #triton_gpu.shared<{
    vec = 8,        // 向量化宽度
    perPhase = 1,
    maxPhase = 8,
    order = [1, 0]  // 列优先存储
}>

perPhase / maxPhase 控制 XOR swizzle 周期(避免 bank conflict,详见第 5.7 节)。

#mma / #nvidia_mma — Tensor Core 输出布局

text
#mma = #triton_gpu.nvidia_mma<{
    versionMajor = 3,         // SM90 wgmma
    warpsPerCTA = [4, 1],
    instrShape = [16, 64, 16] // 单条 wgmma 指令的 M×N×K
}>

tt.dot 的输出走这种 layout,让 reduce 沿 N 维很快。AMD 上对应 #amd_mfma

layout 就是"切菜方案"

同一个 tensor 可以有多种 layout:放寄存器的 #blocked、放 SMEM 的 #shared、Tensor Core 输出的 #mmatriton_gpu.convert_layout op 在不同 layout 之间转换。编译器的关键目标之一就是减少这种转换——每次 convert 都意味着 SMEM 来回搬运 + 同步。

10.4.2 关键优化 pass 详解

Triton 3.3 上一个 SM90 matmul kernel 的实际 TTGIR pass pipeline(来源:issue #8546):

完整 pass 顺序
convert-triton-to-tritongpu
coalesce
F32DotTC
accelerate-matmul
remove-layout-conversions
optimize-thread-locality
optimize-dot-operands
tmem-alloc
tritongpu-assign-latencies{num-stages=3}
tritongpu-schedule-loops
tritongpu-automatic-warp-specialization{num-stages=3}
tritongpu-pipeline{num-stages=3}
tritongpu-combine-tensor-select-and-if
tritongpu-hoist-tmem-alloc
triton-loop-aware-cse
tritongpu-prefetch
tritongpu-optimize-dot-operands{hoist-layout-conversion=true}
tritongpu-coalesce-async-copy
triton-nvidia-optimize-tmem-layouts
tritongpu-remove-layout-conversions

我们重点讲五个对性能影响最大的:

Pass 1:tritongpu-coalesce

分析 tt.load / tt.store 的访存 stride,把 contiguous 维度调整到最快变的轴上。

做什么:让相邻 thread 落在相邻地址,命中 128 字节 sector 合并。

没有它会怎样:标量 load 退化成多次未合并的 32B 事务,带宽利用率掉到 12.5%(参见第 3.11 节定量分析)。

怎么验证:dump TTGIR 看 #blockedorder 字段是否对齐 tensor 的 contiguous 维度。

Pass 2:tritongpu-accelerate-matmul

检查 tt.dot 的形状、dtype、SM 能力,把通用 #blocked 布局换成硬件 native 的 tensor core 布局:

架构选用指令输出 layout
SM80 (A100)ldmatrixmma.sync.m16n8k16#nvidia_mma<v2>
SM90 (H100)TMA → wgmma.m64n128k16#nvidia_mma<v3>
SM100 (B200)TMA → tcgen05.mma (tensor memory)#tmem

这个 pass 是 Triton 能跑出 cuBLAS 90%+ 性能的关键——它一手包办了从"tl.dot"到"硬件最优 tensor core 指令"的全部决策。

Pass 3:tritongpu-remove-layout-conversions

消除冗余的 convert_layout op。例如:

text
// 优化前
%a_blocked = ... : tensor<128x32xf16, #blocked>
%a_dot = convert_layout %a_blocked : tensor<128x32xf16, #dot_op>
%a_blocked2 = convert_layout %a_dot : tensor<128x32xf16, #blocked>   // 冗余
%a_dot2 = convert_layout %a_blocked2 : tensor<128x32xf16, #dot_op>   // 冗余

// 优化后
%a_dot = convert_layout %a_blocked : tensor<128x32xf16, #dot_op>

每个 convert 都是一次 SMEM 往返 + warp 同步。消除一次 convert,单次 GEMM 主循环能省几十个 cycle。

Pass 4:tritongpu-pipeline(最关键的性能 pass)

识别 for 循环中"load → use"链,把 K 次循环里的 load 与第 K-N 次的 dot 重叠。

底层变换

text
// 优化前
scf.for %k = 0 to %K step %BK {
  %a = tt.load %ap : ...
  %b = tt.load %bp : ...
  %acc = tt.dot %a, %b, %acc : ...
  // 编译器看不到 prefetch 机会
}

// 优化后(num_stages=3)
%buf_a = triton_gpu.local_alloc : !tt.memdesc<3x128x32xf16, #shared>
%buf_b = triton_gpu.local_alloc : !tt.memdesc<3x32x128xf16, #shared>

// Prologue: 预取 K=0, 1
triton_gpu.async_copy_global_to_local %ap0, %buf_a[0]
triton_gpu.async_copy_global_to_local %ap1, %buf_a[1]
...

scf.for %k = 0 to %K step %BK {
  triton_gpu.async_wait {num = 1}                  // 等 k-2 的 cp.async
  %a = triton_gpu.local_load %buf_a[%k % 3] : ...
  %b = triton_gpu.local_load %buf_b[%k % 3] : ...
  %acc = tt.dot %a, %b, %acc : ...
  triton_gpu.async_copy_global_to_local %ap_next, %buf_a[(%k+2)%3]
  triton_gpu.async_copy_global_to_local %bp_next, %buf_b[(%k+2)%3]
}

关键点

  • 申请 num_stages 份 SMEM buffer
  • tt.load 被替换成 triton_gpu.async_copy_global_to_local(lower 到 PTX cp.async
  • 在使用点插入 triton_gpu.async_wait,并立即发起下一轮 load

怎么失败

  • 循环结构不规范(有副作用、依赖 carry-dependent 控制流)→ pass 跳过
  • SMEM 不够 → 编译器报 out of resource: shared memory
  • transfer size < 4 bytes → cp.async 不支持,编译错误(参见 issue #5882)

Pass 5:tritongpu-automatic-warp-specialization(SM90+)

把 producer warp(用 TMA 加载)和 consumer warp(用 wgmma 计算)拆到不同 warpgroup,模仿 cuBLAS 的 persistent ping-pong 调度。

A100 上没有这个 pass(要 Hopper TMA)。H100 上 num_stages ≥ 3 时会自动启用,让 load 和 compute 真正并行(而不是软件流水线意义下的"看起来并行")。详见第 12 章 Hopper 特性。

10.4.3 其他重要 pass

  • tritongpu-prefetch:与 pipeline 配对,处理 SMEM → register 的 N-Buffer,底层指令 ldmatrix
  • tritongpu-assign-latencies + tritongpu-schedule-loops:3.x 引入的两段式调度——先给每条访存/计算估算 latency,再依此排"何时插 wait、何时下一次发射"。
  • tritongpu-coalesce-async-copy:合并相邻 cp.async 为更宽指令。
  • F32DotTC / TF32x3:根据 TRITON_F32_DEFAULT 把 fp32 dot 改写成 TF32(单 pass)或 TF32×3(三次 tf32 累加补偿精度)。
  • TMA 相关:tma-materializationtma-multicasttt.descriptor_load 下降到 cp.async.bulk.tensor(Hopper TMA)。

AMD backend 还有一组专属 pass:AMDGPU-stream-pipelineAMDGPU-block-pingpongAMDGPU-optimize-epilogueAMDGPU-reorder-instructions(降寄存器压力)。


10.5 LLVM IR → PTX → CUBIN

TTGIR 之后的事情主要由 LLVM 和 NVIDIA 工具链负责,Triton 只是个调用者。但了解几个关键节点能帮你调试。

10.5.1 LLVM IR 阶段

make_llir 把 TritonGPU lower 到 LLVM dialect(带 NVVM intrinsic 和 inline asm,例如 ldmatrix.sync.aligned.m8n8.x4mma.sync.aligned.m16n8k16),再 export 成 LLVM IR。

之后跑 LLVM 中端 pass:通用 SCEV、LICM、强度削弱、循环展开。这些 pass 受 DISABLE_LLVM_OPT 环境变量控制:

bash
# 关闭 Loop Strength Reduction(在部分寄存器吃紧的 kernel 上反而提速 10%)
DISABLE_LLVM_OPT="disable-lsr" python my_kernel.py

10.5.2 PTX 阶段

LLVM NVPTX backend 做:指令选择 → 寄存器分配 → 调度 → 文本化为 PTX。

注意:LLVM 生成的 register 是 虚拟无限的(PTX SSA 形式)。例如 %r1%r2、…… 可以编到几千。这一阶段不存在 spill。

10.5.3 CUBIN 阶段

ptxas 是 NVIDIA 闭源工具,做:

  1. 虚拟寄存器 → 物理寄存器:≤255 个物理 reg。
  2. 超出则 spill 到 local memory:实质走 L1 → L2 → DRAM,延迟从 1 cycle 涨到 30~400 cycle。
  3. 指令调度:对 SASS 级别的 dual-issue 做调度。
  4. 生成 CUBIN:SM-specific 机器码。

ptxas 是性能的"最后一公里"

同一份 PTX 喂给不同版本的 ptxas,性能能差 5~10%。NVIDIA 偶尔会在 CUDA 工具链更新里悄悄改 ptxas 的寄存器分配算法。如果你在做严格的性能回归测试,记得固定 CUDA 版本。

要看 ptxas 做了什么,打开调试输出:

bash
TRITON_DUMP_PTXAS_LOG=1 python my_kernel.py

会输出:

ptxas info  : Used 192 registers, 32 stack frame, 24 bytes spill stores, 24 bytes spill loads
ptxas info  : Function properties for matmul_kernel
ptxas warning : Local memory used for function 'matmul_kernel', size of stack frame: 32 bytes

三个关键数字:

  • Used N registers:单 thread 寄存器数(≤255)
  • M bytes spill stores / loads真正 spill 到 local memory 的次数(≠0 就警报)
  • stack frame:包含 spill + 显式局部数组

10.6 如何 dump 各阶段 IR

10.6.1 方式 1:Python API(已编译 kernel 上)

最方便、最精准的方式:

python
import torch, triton, triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    off = pid * BLOCK + tl.arange(0, BLOCK)
    m = off < n
    tl.store(out_ptr + off,
             tl.load(x_ptr + off, mask=m) + tl.load(y_ptr + off, mask=m),
             mask=m)

x = torch.rand(1024, device='cuda')
y = torch.rand_like(x)
out = torch.empty_like(x)

# 触发一次 JIT
compiled = add_kernel[(1,)](x, y, out, 1024, BLOCK=1024)

print(compiled.asm.keys())
# dict_keys(['ttir', 'ttgir', 'llir', 'ptx', 'cubin'])

# 看任意阶段
print(compiled.asm['ttir'])     # 块级 IR
print(compiled.asm['ttgir'])    # 带 layout 的 IR
print(compiled.asm['llir'])     # LLVM IR
print(compiled.asm['ptx'])      # PTX 汇编
# compiled.asm['cubin']         # 二进制,无法直接 print

10.6.2 方式 2:环境变量(pass-by-pass dump)

适合排查"哪个 pass 把代码变坏了":

环境变量作用
MLIR_ENABLE_DUMP=1在每个 MLIR pass 之前 dump IR;可写成 MLIR_ENABLE_DUMP=kernelName 只 dump 指定 kernel
MLIR_DUMP_PATH=/tmp/dump.mlir把 MLIR dump 写到文件而非 stderr
LLVM_IR_ENABLE_DUMP=1在每个 LLVM pass 之前 dump LLVM IR
NVPTX_ENABLE_DUMP=1dump NVPTX backend 生成的中间产物
AMDGCN_ENABLE_DUMP=1AMD 对应版本
TRITON_KERNEL_DUMP=1dump 所有阶段成品到 TRITON_DUMP_DIR(默认 ~/.triton/dump
TRITON_PRINT_AUTOTUNING=1autotune 完成后打印每个 config 的耗时和最优 config
TRITON_DUMP_PTXAS_LOG=1显示 ptxas 的寄存器/spill 报告
TRITON_ALWAYS_COMPILE=1跳过缓存(dump 时必备,否则缓存命中没有 pass 输出)
TRITON_REPRODUCER_PATH=foo.mlir把每阶段输入存成可复现 MLIR 文件,编译失败时锁定崩溃 pass
TRITON_DISABLE_LINE_INFO=1dump 时去掉 Python 行号噪声
USE_IR_LOC=ttirttgir把后续 IR 的 source location 改写到 TTIR/TTGIR 行号
TRITON_INTERPRET=1走 Python 解释器跑 kernel(可下 pdb 断点),完全跳过编译
TRITON_F32_DEFAULT={ieee,tf32,tf32x3}控制 tl.dot 的 fp32 精度策略
TRITON_CACHE_DIR=/pathTRITON_HOME=/path切换缓存目录(只读 HOME 场景必用)

常用组合(一次性看到所有 NVIDIA stages):

bash
TRITON_ALWAYS_COMPILE=1 \
MLIR_ENABLE_DUMP=1 \
LLVM_IR_ENABLE_DUMP=1 \
NVPTX_ENABLE_DUMP=1 \
TRITON_DISABLE_LINE_INFO=1 \
TRITON_DUMP_PTXAS_LOG=1 \
python my_kernel.py 2> dump.log

# 之后 grep 找特定 pass 前后差异
grep -A 50 "After Pass: tritongpu-pipeline" dump.log

10.6.3 方式 3:直接读缓存目录

bash
ls ~/.triton/cache/
# 0a3f9b...
# 1c4e2d...

ls ~/.triton/cache/0a3f9b.../
# add_kernel.ttir
# add_kernel.ttgir
# add_kernel.llir
# add_kernel.ptx
# add_kernel.cubin
# add_kernel.json     ← 元数据(num_warps、shared、constexpr 值等)

直接 cat 任意 IR 文件即可。详见 10.8 节缓存结构。


10.7 实例:向量加法的完整 IR 变换

把 10.3 的源码顺着流水线走一遍。每一阶段我们只看关键变化。

10.7.1 源码

python
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    off = pid * BLOCK + tl.arange(0, BLOCK)
    mask = off < n_elements
    x = tl.load(x_ptr + off, mask=mask)
    y = tl.load(y_ptr + off, mask=mask)
    tl.store(out_ptr + off, x + y, mask=mask)

调用:add_kernel[(1,)](x, y, out, 1024, BLOCK=1024),假设 num_warps=4,目标 A100 (sm_80)。

10.7.2 TTIR(机器无关)

见 10.3.4 节。要点:tensor 类型纯净,只有 tensor<1024xf32>,无 layout。

10.7.3 TTGIR(添加 layout)

text
#blocked = #triton_gpu.blocked<{
    sizePerThread = [8],
    threadsPerWarp = [32],
    warpsPerCTA = [4],
    order = [0]
}>

module attributes {
    "triton_gpu.num-warps" = 4 : i32,
    "triton_gpu.threads-per-warp" = 32 : i32
} {
  tt.func public @add_kernel(...) {
    // ... 和 TTIR 类似,但所有 tensor 类型上多了 #blocked encoding
    %xv = tt.load %xp, %mask
          : tensor<1024x!tt.ptr<f32>, #blocked>
    %sum = arith.addf %xv, %yv
           : tensor<1024xf32, #blocked>
    tt.store %op, %sum, %mask
             : tensor<1024x!tt.ptr<f32>, #blocked>
    tt.return
  }
}

关键变化

  1. 新增 #blocked layoutsizePerThread=8 意味着每 thread 拿 8 个连续 fp32(= 32 字节 = 一条 ld.global.v4.b32 × 2 的宽度)。
  2. coalesce pass 选定 order=[0]:最快变维度对齐到 axis 0,让相邻 thread 拿相邻地址。
  3. 总元素分配4 warp × 32 thread × 8 elem = 1024,完美覆盖 BLOCK=1024。

10.7.4 LLVM IR(关键片段)

text
define ptx_kernel void @add_kernel(ptr addrspace(1) %x,
                                     ptr addrspace(1) %y,
                                     ptr addrspace(1) %out,
                                     i32 %n) #0 {
  %pid = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %tid = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  ; ... 算 offset, mask ...

  ; 向量化 load (4 路)
  %15 = getelementptr float, ptr addrspace(1) %x, i64 %14
  %18 = tail call { i32, i32, i32, i32 } asm sideeffect
        "@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];",
        "=r,=r,=r,=r,l,b"(ptr addrspace(1) %15, i1 %12)

  ; ... 解构 4 个 float、做加法 ...
  %add1 = fadd float %x1, %y1
  %add2 = fadd float %x2, %y2
  %add3 = fadd float %x3, %y3
  %add4 = fadd float %x4, %y4

  ; 向量化 store (4 路)
  tail call void asm sideeffect
        "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", ...
  ret void
}

关键变化

  • tt.load 被 lower 成 4 路向量化的 inline asm ld.global.v4.b32——这是 coalesce + vectorize 双 pass 的成果。
  • 每 thread 8 个元素 = 两条 ld.global.v4.b32(每条搬 4 个 fp32 = 16 B)。
  • mask 体现为 @$5 predicate guard。

10.7.5 PTX(NVPTX backend 输出)

ptx
.visible .entry add_kernel(
    .param .u64 add_kernel_param_0,
    .param .u64 add_kernel_param_1,
    .param .u64 add_kernel_param_2,
    .param .u32 add_kernel_param_3
)
{
    ld.param.u64        %rd1, [add_kernel_param_0];
    mov.u32             %r1, %ctaid.x;
    mov.u32             %r2, %tid.x;
    shl.b32             %r9, %r1, 10;             // pid * 1024
    or.b32              %r10, %r9, %r7;           // 加上 thread-level offset
    setp.lt.s32         %p1, %r10, %r3;           // mask
    cvta.to.global.u64  %rd5, %rd1;
    cvt.s64.s32         %rd14, %r10;
    shl.b64             %rd15, %rd14, 2;          // × sizeof(fp32)
    add.s64             %rd16, %rd5, %rd15;
    @%p1 ld.global.v4.b32  {%f1, %f2, %f3, %f4}, [%rd16];
    // ... 同样取 y ...
    add.f32             %f17, %f1, %f9;
    add.f32             %f18, %f2, %f10;
    add.f32             %f19, %f3, %f11;
    add.f32             %f20, %f4, %f12;
    @%p1 st.global.v4.b32  [%rd20], {%f17, %f18, %f19, %f20};
    ret;
}

性能检查:每条 ld.global.v4.b32 一次取 16 B,一个 warp 32 lane × 16 B = 512 B = 4 个 128 B sector,完美 coalesce

10.7.6 一图概览每阶段做了什么

源码                  TTIR                  TTGIR                 PTX
─────                ─────                 ──────                ─────
tl.load(...)    →    tt.load %p, %m   →   tt.load %p, %m   →   ld.global.v4.b32
                                          (with #blocked)        (向量化 + coalesced)

                          coalesce pass 决定 order             accelerate-matmul
                          assign 8 elt/thread                   决定 v4 (16 B)
                                                                 num_warps=4 决定
                                                                 32 thread/warp × 8 = 256 elt

10.8 编译缓存:~/.triton/cache

Triton 用基于 hash 的文件系统缓存。掌握它能避免"为什么改了代码没生效""为什么磁盘满了"这类问题。

10.8.1 缓存目录结构

默认位置 $HOME/.triton/cache/,可被 TRITON_CACHE_DIRTRITON_HOME 覆盖。

~/.triton/cache/
├── 0a3f9b2c.../                       ← 一个 hash 子目录 = 一份独立编译产物
│   ├── matmul_kernel.json             ← 元数据 (num_warps, shared, constexpr 值等)
│   ├── __grp__matmul_kernel.json      ← autotune 分组信息
│   ├── matmul_kernel.ttir
│   ├── matmul_kernel.ttgir
│   ├── matmul_kernel.llir
│   ├── matmul_kernel.ptx              ← NVIDIA
│   └── matmul_kernel.cubin            ← NVIDIA (或 .amdgcn/.hsaco for AMD)
├── 1c4e2d8f.../
│   └── ...

10.8.2 缓存 key(什么时候触发重编译)

缓存 key 由以下因素 hash 而成,任意一项变化即编译新副本

  • Triton 源码版本(含 git hash)+ LLVM 版本
  • Kernel Python 源码(AST 文本化)
  • 所有 tl.constexpr 参数的实际值
  • 每个非 constexpr tensor 参数的 dtype + 是否对齐 16B(specialization)
  • num_warpsnum_stagesnum_ctasmaxnreg
  • GPUTarget(backend、compute capability、warp size)
  • 编译选项(TRITON_F32_DEFAULT 等会进 options)

隐性缓存膨胀

每种 constexpr 组合都是独立编译产物。如果你的 kernel 有 BLOCK_SIZE: tl.constexpr 而你又传了 50 种不同的 BLOCK_SIZE,cache 里会有 50 个子目录。autotune 期间所有 config 都会被编译——一个 autotune 列表 60 个 config × 100 个 shape 就能撑出几 GB。

10.8.3 缓存失效与清理

Triton 自己没有 LRU / TTL / 容量上限,缓存只增不减(参见 issue #9298——长跑服务的动态 shape 场景会撑爆磁盘)。生产环境需要外部脚本定期清理:

bash
# 简单:保留最近 30 天的
find ~/.triton/cache -mtime +30 -delete

# 进阶:按大小限制(保留最新的 2GB)
du -sh ~/.triton/cache/*/ | sort -hr | tail -n +101 | awk '{print $2}' | xargs rm -rf

强制忽略缓存(开发调试用):

bash
TRITON_ALWAYS_COMPILE=1 python my_kernel.py

完全清空:

bash
rm -rf ~/.triton/cache

10.8.4 Cache Manager 扩展接口

triton.runtime.cache 暴露 CacheManager 接口,社区扩展可以实现:

  • triton-dejavu:分布式缓存(多机共享 hash → CUBIN 映射)
  • Red Hat TCE:签名校验(防止 cubin 篡改)
  • 分级 fallback、远程缓存等
python
# 自定义 CacheManager 示例
from triton.runtime.cache import CacheManager

class MyCacheManager(CacheManager):
    def get_file(self, filename: str) -> str | None:
        # 自定义查找逻辑(如查 S3 / Redis)
        return ...

    def put(self, data, filename: str, binary: bool = True) -> str:
        # 自定义存储逻辑
        return ...

# 注册
import triton
triton.runtime.cache._cache_manager_cls = MyCacheManager

移动缓存目录会失效

Triton 把绝对路径写进了 __grp__*.json,把缓存目录 mv 到新位置后必须重编译。CI 缓存恢复时记得保持路径一致。


10.9 Triton 与 MLIR 的关系

Triton 自 2.0 起完全重写为 MLIR-native 编译器。它定义了 5+ 个 dialect:

Dialect名字作用
tritonTTIRtile 级别、机器无关。tt.loadtt.storett.dottt.make_rangett.splattt.broadcast
triton_gpuTTGIR在 TTIR op 上加 #blocked / #shared / #dot_op / #nvidia_mma / #amd_mfma 等 layout attribute;新增 triton_gpu.async_copy_global_to_locallocal_alloclocal_load 等共享内存相关 op
triton_nvidia_gpuNVIDIA 特有 op,如 tcgen05.mma(Blackwell)、barrier 指令、TMA descriptor_load
triton_amd_gpuAMD 特有 op,如 MFMA 调度提示、buffer_load
triton_proton / triton_kperfprofiler 注入 op(KPerfIR 论文)

Triton 还复用 MLIR 上游 dialect:arith(标量算术)、math(超越函数)、scf(for/if/while)、tensorbuiltinllvm(lower 终点)。

Pass Manager 直接用 mlir::PassManager,所以 Triton pass 写法与上游一致,可被 mlir-opt 单独跑。这对编译器开发者很重要:

bash
# 把 TTIR 单独喂给某个 pass 跑(无需整条流水线)
mlir-opt \
    --triton-rewrite-tensor-pointer \
    --canonicalize \
    --cse \
    my_kernel.ttir

学习 Triton 编译器的快速路径

  1. 通读 python/triton/backends/nvidia/compiler.pyadd_stages
  2. 对应每个 pass 名字去 lib/Dialect/TritonGPU/Transforms/ 找 C++ 实现
  3. MLIR_ENABLE_DUMP=1 看 pass 实际怎么改 IR
  4. Lei Zhang 的 Triton Compiler Development Tips

10.10 版本演进简表

本章基于 Triton 3.0 ~ 3.3。重要演进:

版本关键改动
2.x → 3.0引入 tritongpu-assign-latencies 两段式调度
3.1Hopper warp specialization 自动化(tritongpu-automatic-warp-specialization
3.2Blackwell SM100 支持(tcgen05.mmatmem-alloc pass)
3.3FP8 lowering 在 TTGIR 阶段统一加 layout encoding(discussion #9051)

如果你在 Triton 2.x 上:kernel.asm key 集合相同;tritongpu-pipeline 用旧版(无 latency-aware 调度);没有 warp specialization / TMA pass。

Compiler Explorer 已支持 Triton 2.3 ~ 3.3.1 各版本在线对比 IR(RFC #7560),做版本回归调试很方便:https://github.com/triton-lang/triton/issues/7560


本章小结

  • 六阶段流水线:AST → TTIR → TTGIR → LLVM IR → PTX → CUBIN。每一层只解决自己抽象级别的问题。
  • TTIR 是块级、机器无关的"标准化代码"。这层做 CSE、LICM、常量折叠等通用优化。
  • TTGIR 是 Triton 真正的优化主战场——加 layout 编码,跑 coalesce / accelerate-matmul / pipeline / warp-specialization 等关键 pass。90% 的性能秘密都在这层
  • LLVM + ptxas 负责寄存器分配和最终机器码——Triton 不参与,但通过 ptxas 日志可以验证 spill 情况。
  • dump IRkernel.asm['<stage>'](最快)或 MLIR_ENABLE_DUMP=1 + LLVM_IR_ENABLE_DUMP=1(最全)。
  • 编译缓存~/.triton/cache/,只增不减——生产环境要外部清理。
  • Triton 完全 MLIR-native,自定义了 5+ 个 dialect,pass 可用 mlir-opt 单独跑。

下一章我们把性能优化推到定量层面:用 Roofline 模型预知上限、用 Nsight Compute 定位瓶颈、用 occupancy 公式诊断寄存器压力。


思考题

  1. dump 实战:写一个 30 行的 fused LayerNorm kernel,分别 dump 它的 TTIR 和 TTGIR。对比观察:TTGIR 多出了哪些 #blocked / #shared layout?tritongpu-pipeline pass 在 num_stages=2 vs num_stages=4 时产物有什么区别?(提示:用 MLIR_ENABLE_DUMP=1 并 grep pipeline

  2. 缓存失效诊断:你写了一个生产服务,inputs 的 batch_size 在 1~64 之间动态变化,发现服务跑了 3 天后磁盘满了。具体分析:

    • (a) 为什么 batch_size 这个普通 int 参数会导致缓存爆炸?(提示:specialization)
    • (b) 怎么用 do_not_specialize 缓解?代价是什么?
    • (c) 设计一个 cron 脚本,在不影响热缓存命中的前提下定期清理。
  3. Pass 推理题:你发现一个 matmul kernel 在 A100 上 num_stages=3 跑 250 TFLOPS,把 BLOCK_K 从 32 改到 64 后只剩 180 TFLOPS。请用编译器 pass 的视角推理可能原因(至少给出 3 种假设),并写出每种假设的验证方法(dump 哪一阶段、看什么 metric)。

基于 MIT 协议发布