13. 生产级案例深度剖析
走进 vLLM、SGLang、PyTorch Inductor、Unsloth —— 每天为亿万请求保驾的真实 Triton 代码:它们怎么写、为什么这么写、踩过哪些坑、留下哪些可复用的范式。
前 12 章我们学了 Triton 的语法、内存模型、编译器原理与 Hopper/Blackwell 新硬件。但真正能跑在生产环境的 kernel,远不止"算得对、调得快":它要面对动态 batch、变长序列、跨架构兼容、FP8/INT8 量化的数值稳定性,以及和 PyTorch autograd / torch.compile 的无缝拼装。本章把六类最具代表性的工业实现拆开摆在桌上,逐行解释设计动机与权衡,让你看完能直接把同一套模式抄进自己的项目。
本章内容概览
- 13.1 vLLM PagedAttention 的 Triton 化之路
- 13.2 SGLang RadixAttention 与 Triton 量化栈
- 13.3 PyTorch Inductor 的 Triton 代码生成
- 13.4 Unsloth / Axolotl:LoRA 微调的核函数艺术
- 13.5 FP8 / INT8 量化核函数编写模式
- 13.6 反向传播核函数实战
- 13.7 本章小结与思考题
13.1 vLLM PagedAttention 的 Triton 化之路
13.1.1 为什么 PagedAttention 改变了 LLM 推理
在 2023 年之前,LLM 推理引擎普遍把每个请求的 KV cache 分配成一段连续显存。问题立刻浮现:
- 不同请求长度差异巨大(从 128 token 到 8K token),导致外部碎片
- 预分配最大长度时,内部碎片也极严重(用户实际只用 1/4 序列)
- 实测 SOTA 引擎只能利用 20-40% 的 KV cache 显存(vLLM SOSP'23 论文数据)
vLLM 借鉴操作系统的虚拟内存分页思路提出 PagedAttention:
| 概念 | 操作系统 | PagedAttention |
|---|---|---|
| 页(Page) | 4 KB 物理页 | BLOCK_SIZE token 的 KV(默认 16) |
| 进程 | 用户进程 | 单个 request |
| 页表 | Page Table | Block Table(逻辑块号 → 物理块号) |
| 缺页 | Page Fault | "新 block 申请"(实际是 free list pop) |
KV cache 不再连续 —— 每 16 个 token 一个 block,物理上可以散落在显存任意位置;request 持有 block_table 指明顺序。外部碎片清零,内部碎片缩到最后一个 block。
论文:Efficient Memory Management for Large Language Model Serving with PagedAttention (SOSP'23)。
13.1.2 从 CUDA 到 Triton:vLLM V1 的演化
vLLM 0.6 之前的 attention kernel 是 csrc/attention/attention_kernels.cu 的纯 CUDA 实现。两个痛点推动了 Triton 化:
- 跨架构维护成本:AMD MI300、Intel Gaudi、Apple M 系列都要单独写一份
- 快速迭代的研究改动:Sliding window、ALiBi、chunked prefill 每改一次都要重新写 CUDA + 编译 wheel
vLLM V1(2025 年起)把 Triton attention backend 作为 AMD 与通用硬件的默认实现,文件在 vllm/v1/attention/backends/triton_attn.py:
来源:vLLM Triton Backend Deep Dive、Enabling vLLM V1 on AMD GPUs With Triton。
13.1.3 三阶段演进:单 kernel → 拆分 → 重新合并
Phase 1:单 kernel 同时服务 prefill + decode
最直接的写法:把 grid 设为 (tokens_in_batch, num_query_heads, query_seq_block),每个 program 处理 [BLOCK_M, BLOCK_DMODEL] 的 Q 块对所有 KV blocks 的 attention。
@triton.jit
def unified_attention_kernel(
Q_ptr, K_cache_ptr, V_cache_ptr, Out_ptr,
block_tables_ptr, # [num_seqs, max_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
num_heads, num_kv_heads, head_size,
BLOCK_SIZE: tl.constexpr, # KV cache 物理块大小(如 16)
BLOCK_M: tl.constexpr, # query token 维度的 tile
BLOCK_DMODEL: tl.constexpr,
):
seq_idx = tl.program_id(0)
head_idx = tl.program_id(1)
q_blk_idx = tl.program_id(2)
seq_len = tl.load(seq_lens_ptr + seq_idx)
block_table_offset = seq_idx * MAX_BLOCKS_PER_SEQ
# 逐 physical block 取 K/V,做 online softmax
...问题:decode 阶段 query_len = 1,但 grid 第 3 维仍按 BLOCK_M 切分;多数 program 立刻 mask 退出,SM 占用率掉到 30%。
Phase 2:拆成两个 kernel
prefix_prefill_kernel:query_len > 1时使用,含 chunked prefillpaged_attention_2d_kernel:query_len == 1时使用,grid 改成(num_seqs, num_query_heads),每个 program 处理 1 个 query token 对整段历史 KV 的 attention,BLOCK_N = 16对齐 vLLM 默认 page size
@triton.jit
def paged_attention_2d_kernel(
Q_ptr, K_cache_ptr, V_cache_ptr, Out_ptr,
block_tables_ptr, seq_lens_ptr,
scale,
num_query_heads, num_kv_heads, head_size,
BLOCK_SIZE: tl.constexpr, # 16
BLOCK_DMODEL: tl.constexpr,
):
seq_idx = tl.program_id(0)
head_idx = tl.program_id(1)
seq_len = tl.load(seq_lens_ptr + seq_idx)
num_blocks = tl.cdiv(seq_len, BLOCK_SIZE)
# 单个 query token:装进寄存器
q = tl.load(Q_ptr + seq_idx * num_query_heads * head_size
+ head_idx * head_size + tl.arange(0, BLOCK_DMODEL))
q = q * scale
# online softmax 累加器
m_i = -float("inf")
l_i = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
kv_head_idx = head_idx // (num_query_heads // num_kv_heads) # GQA
for blk_idx in range(num_blocks):
# 查 block_table 拿到物理块号
physical_blk = tl.load(block_tables_ptr
+ seq_idx * MAX_BLOCKS_PER_SEQ + blk_idx)
# 装入这个 block 的 K (BLOCK_SIZE × head_size)
k_offs = (physical_blk * num_kv_heads * BLOCK_SIZE * head_size
+ kv_head_idx * BLOCK_SIZE * head_size
+ tl.arange(0, BLOCK_SIZE)[:, None] * head_size
+ tl.arange(0, BLOCK_DMODEL)[None, :])
k = tl.load(K_cache_ptr + k_offs)
# 计算 q·K^T(注意是单 query token 对 BLOCK_SIZE 个 key)
s = tl.sum(q[None, :] * k, axis=1)
# mask 超出 seq_len 的位置
tokens_in_blk = tl.arange(0, BLOCK_SIZE) + blk_idx * BLOCK_SIZE
s = tl.where(tokens_in_blk < seq_len, s, -float("inf"))
# online softmax 更新
m_new = tl.maximum(m_i, tl.max(s, axis=0))
alpha = tl.exp(m_i - m_new)
p = tl.exp(s - m_new)
# 装入 V,更新 acc
v = tl.load(V_cache_ptr + k_offs)
acc = acc * alpha + tl.sum(p[:, None] * v, axis=0)
l_i = l_i * alpha + tl.sum(p, axis=0)
m_i = m_new
acc = acc / l_i
tl.store(Out_ptr + seq_idx * num_query_heads * head_size
+ head_idx * head_size + tl.arange(0, BLOCK_DMODEL),
acc.to(tl.float16))实测:拆分后 throughput +3.7×(来源同上)。
Phase 3:又合回单 kernel(unified kernel)
维护两份代码代价巨大 —— bug 修复要同步、autotune 配置要双份、行为差异容易漏。vLLM 当前主线正在把两条路再次统一回单 kernel,关键技巧:用 tl.constexpr 的 IS_DECODE 分支让编译器在编译期消除 dead code,runtime 仍是两套机器码但源码只一份。
13.1.4 Prefix Caching 的实现细节
vLLM V1 默认开启 Automatic Prefix Caching (APC):
请求 1: "You are a helpful assistant. Translate to French: hello"
请求 2: "You are a helpful assistant. Summarize this text: ..."
└───── 共享前缀 ─────┘实现路径:
- 块级 hash:对每个 KV block,把 token IDs 与上一块的 hash 做链式哈希(
hash(prev_hash, tokens_in_this_block)) - 命中查询:新请求 prefill 时按块查 hash 表,命中则直接复用物理块 + bump refcount
- 驱逐:LRU on blocks,未引用且最旧的优先
代码:vllm/v1/core/kv_cache_utils.py 的 BlockHashType 与 hash_block_tokens。
chunked prefill 的已知限制
当 APC 与 chunked prefill 同时启用时,只有 prompt 的第一个 chunk 享受 prefix cache(GH #7883)。原因:后续 chunk 的 hash 计算依赖前一 chunk 已经写入 KV cache 的状态,而该状态被 chunked 切散,链式 hash 容易错位。修复方案在讨论中。
13.1.5 取走的工程模式
1. 物理 KV 与逻辑 KV 分离 → block_table 间接寻址
2. 写一份 Triton 顶替多份 CUDA → 跨架构维护成本骤降
3. 用 grid 维度形态匹配 workload 形态(prefill vs decode)→ 占用率
4. tl.constexpr 分支统一 source、分裂 binary → 既减代码又不损性能
5. 块级 hash 实现 token 级缓存复用 → 无锁、O(1) 查询13.2 SGLang RadixAttention 与 Triton 量化栈
13.2.1 RadixAttention:从分页到字典树
SGLang 提出的 RadixAttention 与 vLLM PagedAttention 解决的问题略有不同:
| 维度 | vLLM (APC) | SGLang (RadixAttention) |
|---|---|---|
| 数据结构 | 块 hash 表 | Radix Tree(压缩前缀树) |
| 粒度 | Block-level(16 token 块) | Token-level(不依赖块对齐) |
| 共享发现 | 仅命中链式块 hash | 自动找最长公共前缀 |
| Fork 支持 | 弱 | 原生(agent A/B 分支共享根) |
| 驱逐 | LRU on blocks | LRU on leaf nodes |
| 最佳场景 | 静态 system prompt | 多轮对话、Tree-of-Thought、Agent |
举例:
请求 A: [System prompt 100 tokens] + "翻译 hello"
请求 B: [System prompt 100 tokens] + "解释 hello"
请求 C: [System prompt 100 tokens] + "翻译 world"
Radix Tree:
root ──[System prompt 100]── ┬── "翻译 " ──┬── "hello" (req A)
│ └── "world" (req C)
└── "解释 hello" (req B)任意公共前缀(即使不是 16 倍数)都能复用,且新 fork 的请求几乎零启动开销。
源码:sglang/srt/mem_cache/radix_cache.py,变种 SWARadixCache(滑动窗口)、ChunkCache、MambaRadixCache。
论文:SGLang: Efficient Execution of Structured Language Model Programs (NeurIPS 2024)、LMSYS Blog。
13.2.2 SGLang 的 Triton kernel 分布
SGLang 在 NVIDIA 上默认走 FlashInfer + FlashAttention(CUDA 实现),但有大量 Triton kernel 在以下场景使用:
| 路径 | 后端 | 备注 |
|---|---|---|
| NVIDIA attention 主力 | FlashInfer | C++/CUDA |
| AMD attention | Triton | 共享 vLLM 同款 backend |
| FP8 GEMM 默认 | CUTLASS | C++ |
| FP8 GEMM Triton 可选 | Triton | --fp8-gemm-backend triton |
| W8A8 量化 / dequant | Triton | w8a8_fp8, fp8_block_quant |
| MoE expert routing | Triton | fused_moe_kernel |
| Sampling kernel | Triton | top-k/top-p with temperature |
13.2.3 fused_moe_kernel 模式解析
MoE 推理的核心是 grouped GEMM:把每个 token 路由到 top-k 个 expert,对应不同的 (M_e, K) × (K, N) 矩阵相乘。SGLang 的 Triton 实现关键设计:
@triton.autotune(configs=MOE_CONFIGS, key=["N", "K", "top_k"])
@triton.jit
def fused_moe_kernel(
A_ptr, B_ptr, C_ptr,
topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr,
N, K, EM, num_valid_tokens,
stride_am, stride_ak, stride_be, stride_bk, stride_bn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# super-grouping 提升 L2 局部性
pid_m, pid_n = swizzle(pid, num_pid_m, num_pid_n, GROUP_SIZE_M)
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return # padding 区域,全跳过
# 关键 trick 1:sorted_token_ids 已按 expert 排序,连续 BLOCK_SIZE_M 个 token 同属一 expert
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
# 关键 trick 2:用 expert_ids 找当前 tile 对应的 B 矩阵
off_experts = tl.load(expert_ids_ptr + pid_m)
# 之后是标准 GEMM 主循环 + epilogue 乘 routed weight
...要点:
- 排序+padding 把 grouped GEMM 化简为单 kernel:所有同 expert 的 token 连续摆放,每个
BLOCK_Mtile 自然命中同一 expert num_tokens_post_padded记录 padding 后的总 token 数,让 grid 一次性覆盖所有 expertMUL_ROUTED_WEIGHT: tl.constexpr编译期分支,下行/上行 FFN 用不同二进制(下行要乘 router gate,上行不乘)
这套模式被 Unsloth、Megatron、PyTorch grouped_mm 反复复用,是 MoE Triton kernel 的事实标准。
13.2.4 FP8 block-scaled GEMM
DeepSeek-V3 流行的 block-scaled FP8 在 SGLang 的 Triton 实现:
- Activation 按
[1, 128]分组、weight 按[128, 128]分组,每组一个 FP32 scale - 主 GEMM 在 Tensor Core 上做 FP8 × FP8 → FP32
- scale 在 K 维内层乘进 accumulator:每
BLOCK_K = 128步乘一次对应的(a_scale, b_scale)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl.load(a_fp8_ptr + ...) # FP8
b = tl.load(b_fp8_ptr + ...) # FP8
a_s = tl.load(a_scale_ptr + k) # FP32, [BLOCK_M, 1]
b_s = tl.load(b_scale_ptr + k * stride) # FP32, [1, BLOCK_N]
acc += tl.dot(a, b, out_dtype=tl.float32) * (a_s * b_s)Dequant overhead 陷阱
* (a_s * b_s) 在 CUDA Core(FP32)而非 Tensor Core 上执行。H100 FP32 CUDA Core 仅有 FP8 Tensor Core 的 1.6% 算力,单次 dequant 等价于 ~60 个 FP8 MAC 的代价。深思的工程师会让 BLOCK_K 取较大值(≥128),把 dequant 频率压到最低;极端场景把所有 scale 收尾时一次性应用(两级 scaling),但这要求权重 fine-grain 量化误差受控。详见 13.5 节。
13.3 PyTorch Inductor 的 Triton 代码生成
13.3.1 torch.compile 的全栈视角
Python source
│
▼ TorchDynamo(字节码级符号执行)
FX Graph (高层算子)
│
▼ AOTAutograd(前向 + 反向联合捕获)
Joint Forward+Backward FX Graph
│
▼ Inductor lowering(torch/_inductor/lowering.py)
Inductor IR (pointwise / reduction / GEMM 三类节点)
│
▼ Inductor fusion pass
Fused Inductor IR
│
├─→ GEMM/Linear → CUTLASS / cuBLAS / cuBLASLt
├─→ Pointwise/Reduction → ★ Triton codegen ★
└─→ Unsupported → ATen / 直接调 CUDA入口:torch/_inductor/compile_fx.py,Triton codegen 在 torch/_inductor/codegen/triton.py。
13.3.2 一个最小示例:fused add + relu
import torch
def f(x, y):
return torch.relu(x + y)
compiled = torch.compile(f)
x = torch.randn(1024, 1024, device="cuda")
y = torch.randn(1024, 1024, device="cuda")
compiled(x, y)设置 TORCH_LOGS="output_code" 后 Inductor 生成(简化):
@triton.jit
def triton_poi_fused_add_relu_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
XBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)
xmask = xindex < xnumel
tmp0 = tl.load(in_ptr0 + xindex, xmask)
tmp1 = tl.load(in_ptr1 + xindex, xmask)
tmp2 = tmp0 + tmp1
tmp3 = tl.full([1], 0, tl.float32)
tmp4 = triton.language.maximum(tmp2, tmp3)
tl.store(out_ptr0 + xindex, tmp4, xmask)命名约定揭示了类型:
| 前缀 | 含义 | 典型 op |
|---|---|---|
triton_poi_* | pointwise kernel | add, mul, relu, sigmoid |
triton_red_* | reduction kernel | sum, mean, max |
triton_per_* | persistent reduction | 单 tile 装得下整行的 softmax |
triton_tem_* | template kernel | matmul template-based |
13.3.3 调试 Inductor 输出的三种武器
武器 1:TORCH_LOGS
TORCH_LOGS="output_code" python my_model.py
# 等价 Python:
torch._logging.set_logs(output_code=True)直接打到 stderr,看每个 fused kernel 的 Triton 源码。
武器 2:TORCH_COMPILE_DEBUG
TORCH_COMPILE_DEBUG=1 python my_model.py在执行目录创建 torch_compile_debug/run_*/ 子目录,含:
torch_compile_debug/run_2026_05_28_10_30_00/
├── torchinductor/
│ └── model__0_inference_0/
│ ├── fx_graph_readable.py # 人类可读的 FX graph
│ ├── fx_graph_runnable.py # 可独立运行的复现脚本
│ ├── fx_graph_transformed.py # 经过 lowering 后
│ ├── ir_pre_fusion.txt # Inductor IR(融合前)
│ ├── ir_post_fusion.txt # 融合后
│ └── output_code.py # ★ 最终 Triton kernel
└── dynamo/
└── ...ir_pre_fusion.txt vs ir_post_fusion.txt 对比能直观看到哪些 op 被融合掉了,进而判断是否触发预期融合。
武器 3:TORCHINDUCTOR_CACHE_DIR
TORCHINDUCTOR_CACHE_DIR=/tmp/my_cache python my_model.py控制编译缓存位置,方便清理或在多个项目间隔离。
13.3.4 Inductor 融合策略关键规则
读 torch/_inductor/scheduler.py 后能总结出几条Inductor 决定是否融合两个节点的核心规则:
| 规则 | 例子 |
|---|---|
| 必须共享至少一个 read,或一个的输出是另一个的输入 | add(x,y) + relu(out) ✅;add(x,y) + mul(a,b) ❌ |
| 形状必须兼容广播 | [B,N] + [B,N] ✅;[B,N] + [N,B] ❌(需 transpose) |
| 不能跨越Reduction 边界(除 persistent reduction) | sum 后 relu 通常各为一个 kernel |
| buffer 复用约束:被多个下游消费的中间 buffer 必须物化 | 写出 extern_kernels.copy_ 而非融合 |
让融合更激进的实用技巧
- 把
.contiguous()放到管道最早位置,避免融合被 stride mismatch 打断 - 避免 in-place op 在
torch.compile函数体内(如x.add_(y)),会阻断融合 - 对 reduction-heavy 模型,设
torch._inductor.config.triton.unique_kernel_names = True后看命名能否合并
13.3.5 Inductor 自动调优
torch/_inductor/runtime/triton_heuristics.py 中的 triton_config_with_settings() 是 Inductor 的 mini-autotune:
# 简化版伪代码
def triton_config(size_hints, x, y=None, z=None, num_stages=1):
cfg = {"XBLOCK": x}
if y is not None: cfg["YBLOCK"] = y
if z is not None: cfg["ZBLOCK"] = z
return triton.Config(
cfg,
num_warps=num_warps_heuristic(size_hints),
num_stages=num_stages,
)它会枚举 XBLOCK ∈ {64, 128, 256, 512, 1024, 2048}、num_warps ∈ {2, 4, 8},跑一次 benchmark 选最优。结果缓存到 ~/.cache/torch_inductor/,避免每次重跑。
不要在长 benchmark 中信第一次结果
Inductor 第一次 compile 包含 autotune 时间,远慢于稳态。用 torch.compile(..., mode="max-autotune") 后必须 warm up 至少 3 次才能拿到代表性数字。
13.4 Unsloth / Axolotl:LoRA 微调的核函数艺术
13.4.1 LoRA 微调的瓶颈在哪里
LoRA 把权重更新拆为 W' = W + α·B·A(A, B 是低秩矩阵)。朴素 PyTorch 实现:
def lora_forward(x, W, A, B, scale):
return x @ W.T + scale * (x @ A.T @ B.T)PyTorch autograd 默认会保存所有中间张量用于 backward:
x @ A.T:shape[B, r],r 通常 16-64,开销小x @ W.T:shape[B, out],开销大但必存- forward activation 必存到 backward 调用
实际显存占用经常被中间 activation 拉爆。Unsloth 的解法:自定义 torch.autograd.Function,手工选择哪些张量存、哪些重算。
13.4.2 Unsloth 的五大核心优化
优化 1:Fused activation kernel(SwiGLU / GEGLU)
LLaMA 风格的 MLP 是 down(silu(gate(x)) * up(x))。Unsloth 把 silu * up 融合进单 Triton kernel,省掉中间张量:
@triton.jit
def fused_swiglu_kernel(gate_ptr, up_ptr, out_ptr, n_elements,
BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n_elements
g = tl.load(gate_ptr + offs, mask=mask)
u = tl.load(up_ptr + offs, mask=mask)
# silu(g) * u = g * sigmoid(g) * u
silu_g = g * tl.sigmoid(g)
tl.store(out_ptr + offs, silu_g * u, mask=mask)backward 同样融合:
@triton.jit
def fused_swiglu_backward_kernel(grad_out_ptr, gate_ptr, up_ptr,
grad_gate_ptr, grad_up_ptr, n_elements,
BLOCK: tl.constexpr):
...
sig_g = tl.sigmoid(g)
silu_g = g * sig_g
# d(silu_g * u)/du = silu_g
grad_u = grad_out * silu_g
# d(silu_g * u)/dg = u * (sig_g + g * sig_g * (1 - sig_g))
grad_g = grad_out * u * (sig_g * (1.0 + g * (1.0 - sig_g)))
tl.store(grad_gate_ptr + offs, grad_g, mask=mask)
tl.store(grad_up_ptr + offs, grad_u, mask=mask)关键设计:backward 只存 g 和 u,不存 silu(g) * u
重新算 sigmoid(g) 比存一个 [B*S, hidden] 的中间张量便宜得多。Recompute 显存换算力在 SwiGLU 这种 fused activation 上几乎免费。
优化 2:LoRA MLP 算子级融合
Unsloth 写了 LoRA_MLP.forward() —— base + LoRA 的 forward 与 backward 共享中间张量:
class FastLoraMLPForward(torch.autograd.Function):
@staticmethod
def forward(ctx, x, W_gate, W_up, W_down,
gate_A, gate_B, up_A, up_B, down_A, down_B, scale):
# 一次性算完 base
gate_out = x @ W_gate.T + scale * (x @ gate_A.T @ gate_B.T)
up_out = x @ W_up.T + scale * (x @ up_A.T @ up_B.T)
# SwiGLU 融合
act = fused_swiglu(gate_out, up_out)
out = act @ W_down.T + scale * (act @ down_A.T @ down_B.T)
# save_for_backward 只存最少必要张量
ctx.save_for_backward(x, gate_out, up_out, act,
W_gate, W_up, W_down,
gate_A, gate_B, up_A, up_B, down_A, down_B)
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_out):
# 一次 backward 把所有 gradient 算完,避免 PyTorch autograd 的"账本"开销
...PyTorch 默认 autograd 会为每个 @ 操作都记下 op + saved tensors,bookkeeping 本身就消耗 10-15% 时间。手写 Function 直接绕过。
优化 3:关联律重排
数学上 (B @ A) @ x ≡ B @ (A @ x),但显存差距巨大:
B @ A:[out, r] @ [r, in]=[out, in],和原 weight 一样大A @ x:[r, in] @ [in, batch]=[r, batch],r 通常 ≤64
Unsloth 强制走右结合:先 A @ x → [r, batch],再 B @ (A @ x) → [out, batch]。显存省 3-5×。
优化 4:MoE grouped_mm
设 UNSLOTH_MOE_BACKEND=grouped_mm 后启用专用 Triton kernel,借鉴 13.2.3 的 sorted+padded 模式。Unsloth 声称在 gpt-oss / Qwen3-MoE / DeepSeek R1/V3 / GLM 上提供 12× 训练速度、35% VRAM 节省。
来源:Unsloth MoE blog。
优化 5:4-bit / 8-bit 原生集成
LoRA base weight 用 NF4 / Int8 量化存储,Triton kernel 在内层主循环里在线 dequant:
# 简化示意
for k_blk in range(K_blocks):
w_q = tl.load(W_quant + ...) # int4 / uint8
w_scale = tl.load(W_scale + ...) # FP16, 每 block 一个
w = w_q.to(tl.float16) * w_scale # dequant
acc += tl.dot(x, w, acc)13.4.3 Axolotl:开源的 Unsloth-like 路径
Unsloth 只把 LoRA 路径开源(issue #1038),全参数 fine-tune 优化是商业版。Axolotl 基于 Unsloth 思路独立实现:
axolotl/integrations/lora_kernels/:fused SwiGLU/GEGLU forward+backward- 自定义
torch.autograd.Function处理 LoRA MLP/Attention 算子融合 - 同时支持 DDP / DeepSpeed / FSDP2
实测在 H100 / B200 上 LoRA 微调 2-5× 加速、80% 显存节省。
13.4.4 复制到自己项目的实战清单
1. 任何 element-wise activation + 下一步 elementwise,立即写 fused kernel
2. LoRA 路径必走 (A @ x) 优先的关联律重排
3. 自定义 autograd Function 时严格 save_for_backward 最小集
4. 中间张量能 recompute 的尽量 recompute(H100 SMEM 带宽爆炸 vs HBM 紧张)
5. 量化权重在内层主循环 dequant,不要预先物化整个 dequant 矩阵13.5 FP8 / INT8 量化核函数编写模式
13.5.1 三种 scaling 粒度
| 粒度 | scale 数量 | 精度损失 | dequant 开销 | 代表实现 |
|---|---|---|---|---|
| per-tensor | 1 | 高 | 极低(kernel 尾) | TransformerEngine 早期 |
| per-token (activation) / per-channel (weight) | M / N | 中 | 中 | vLLM w8a8_fp8 |
| per-block (e.g. 128×128) | M/128 × K/128 | 低 | 高 | DeepSeek-V3, FP8 GEMM |
13.5.2 PyTorch 官方的 GridQuant + GEMM 双 kernel 路径
PyTorch blog 给出的 H100 FP8 GEMM 经典 pipeline:
Phase 1:2D block quantization (256×256)
@triton.jit
def grid_quant_fp8_kernel(A_ptr, A_fp8_ptr, A_scale_ptr,
M, N, FP8_MAX: tl.constexpr,
BLOCK_M: tl.constexpr, # 256
BLOCK_N: tl.constexpr, # 256
SUB_TILE: tl.constexpr): # 32
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# grid-stride 求 abs_max
abs_max = 0.0
for sub_m in range(0, BLOCK_M, SUB_TILE):
for sub_n in range(0, BLOCK_N, SUB_TILE):
offs_m = pid_m * BLOCK_M + sub_m + tl.arange(0, SUB_TILE)
offs_n = pid_n * BLOCK_N + sub_n + tl.arange(0, SUB_TILE)
a = tl.load(A_ptr + offs_m[:, None] * N + offs_n[None, :])
abs_max = tl.maximum(abs_max, tl.max(tl.abs(a)))
scale = abs_max / FP8_MAX
tl.store(A_scale_ptr + pid_m * tl.cdiv(N, BLOCK_N) + pid_n, scale)
# 第二遍:写出 FP8
inv_scale = 1.0 / scale
for sub_m in range(0, BLOCK_M, SUB_TILE):
for sub_n in range(0, BLOCK_N, SUB_TILE):
...
a_fp8 = (a * inv_scale).to(tl.float8e4nv)
tl.store(A_fp8_ptr + ..., a_fp8)Phase 2:FP8 Tensor Core GEMM + epilogue descale
@triton.jit
def fp8_gemm_kernel(A_fp8_ptr, B_fp8_ptr, C_ptr,
A_scale_ptr, B_scale_ptr,
M, N, K, ...):
...
acc = tl.zeros([BLOCK_M, BLOCK_N], tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl.load(A_fp8_ptr + ...)
b = tl.load(B_fp8_ptr + ...)
acc += tl.dot(a, b) # FP8 × FP8 → FP32
# 在 epilogue 一次性 dequant:每 256×256 块一个 scale
a_s = tl.load(A_scale_ptr + ...) # 标量
b_s = tl.load(B_scale_ptr + ...) # 标量
acc = acc * a_s * b_s
tl.store(C_ptr + ..., acc.to(tl.bfloat16))为什么 256×256 而不是 128×128
- 块越大 → scale 数量越少 → epilogue dequant 开销越低
- 块越大 → 异常 outlier 越容易拉爆 scale → 精度损失越高
- 256×256 是 H100 FP8 在 LLM 场景上的甜点;DeepSeek-V3 选 128×128 是因为推理用,精度更敏感
13.5.3 vLLM 的 per-token-group dynamic FP8
vllm/model_executor/layers/quantization/utils/fp8_utils.py:
@triton.jit
def _per_token_group_quant_fp8(
y_ptr, y_q_ptr, y_s_ptr,
group_size, y_num_columns, y_row_stride,
eps, fp8_min, fp8_max,
use_ue8m0: tl.constexpr, # E8M0 power-of-2 scale 用于 microscaling
BLOCK: tl.constexpr,
):
pid = tl.program_id(0) # 每个 program 处理一个 group
row = pid // num_groups_per_row
col_group = pid % num_groups_per_row
offs = (row * y_row_stride
+ col_group * group_size
+ tl.arange(0, BLOCK))
mask = tl.arange(0, BLOCK) < group_size
y = tl.load(y_ptr + offs, mask=mask)
abs_max = tl.max(tl.abs(y))
scale = tl.maximum(abs_max / fp8_max, eps)
if use_ue8m0:
# MXFP8 风格:scale 量化为 power-of-2(E8M0)
scale = tl.exp2(tl.ceil(tl.log2(scale)))
y_q = (y / scale).to(tl.float8e4nv)
tl.store(y_q_ptr + offs, y_q, mask=mask)
tl.store(y_s_ptr + pid, scale)use_ue8m0=True 启用 MXFP8 微缩放 —— scale 本身用 E8M0(仅指数)表示,硬件可在 Tensor Core 内自动应用,消除 CUDA Core dequant 开销。这是 Blackwell 的关键新能力。
13.5.4 Dequant Overhead 与两级 scaling
复述 13.2.4 的核心:H100 FP32 CUDA Core 仅 FP8 Tensor Core 1.6% 算力,单次 dequant 约 60 个 FP8 MAC 代价。在 K 维 64 次循环中每次 dequant 一遍,会完全抹平 FP8 vs BF16 的增益。
两级 scaling(MOSS、COAT 等论文方案):
fine-grain scale:每 group 一个 E8M0(power-of-2)
coarse scale : 每 tensor 一个 FP32
主循环:acc += tl.dot(a_fp8, b_fp8) # 不做 dequant
(E8M0 group scale 由 Tensor Core 自动应用,零额外指令)
epilogue:acc * a_tensor_scale * b_tensor_scale # 一次性13.5.5 INT8 与 FP8 的关键差异
| 维度 | INT8 | FP8 (E4M3/E5M2) |
|---|---|---|
| 数值表示 | 对称 / 非对称 整数 | 浮点(有 exponent) |
| 动态范围 | 256 个等距值 | E4M3: ±448,更宽 |
| 异常值容忍 | 差(需 SmoothQuant 等手段) | 较好(exponent 覆盖) |
| Quant 公式 | round(x / scale) + zero_point | (x / scale).to(fp8) |
| Tensor Core | A100 / H100 / B200 都支持 | H100+ |
| 典型应用 | CNN、传统 NLP | LLM 推理 |
INT8 Triton kernel 多一层 zero_point 处理 + asymmetric 校正:
q = ((x - x_min) / scale).to(tl.int8)
deq = q.to(tl.float32) * scale + x_minSpeechmatics 的 INT8 落地经验:Fast and Accurate GPU Quantization for Transformers。
13.5.6 vLLM 的 online dynamic FP8 副作用
--quantization fp8 默认开启 activation 动态量化:每次 forward 实时算 abs_max,额外一次 reduce + scale 计算开销在小 batch 时占比可达 8-15%,长输入/高吞吐场景才回本。
何时选择哪种量化路径
- 小 batch 低延迟(chatbot):静态量化(offline calibration),避免 activation 量化的 launch 开销
- 大 batch 高吞吐(API serving):dynamic FP8 + 大 BLOCK_K,dequant 摊薄
- 训练:MXFP8 / FP4(要求 Blackwell),fine-grain scale 抗 outlier
13.6 反向传播核函数实战
13.6.1 经典模板:torch.autograd.Function + Triton
import torch
import triton
import triton.language as tl
# 假设 _attn_fwd_triton 和 _attn_bwd_triton 是写好的 Triton kernel
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale, causal):
o, lse = _attn_fwd_triton(q, k, v, sm_scale, causal)
ctx.save_for_backward(q, k, v, o, lse)
ctx.sm_scale = sm_scale
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
dq, dk, dv = _attn_bwd_triton(q, k, v, o, lse, do,
ctx.sm_scale, ctx.causal)
# 返回元组数量 = forward 入参数;不要梯度的位置返回 None
return dq, dk, dv, None, None
flash_attn = FlashAttnFunc.apply要点:
forward内部调 Triton kernel,PyTorch autograd 不会自动追踪这些计算save_for_backward仅保存 backward 必需张量(FA 只存lse,重算 P)- 返回元组长度与 forward 入参数严格相等
- 用
torch.autograd.gradcheck在 FP64 小规模上做数值校验
13.6.2 推荐路径:torch.library.triton_op + register_autograd
PyTorch 2.4+ 起的新 API,完全兼容 torch.compile 和 AOTInductor,避免 autograd.Function 与 dynamo 的若干 composability 陷阱:
import torch
import triton
@torch.library.triton_op("mylib::flash_attn", mutates_args=())
def flash_attn(q, k, v, sm_scale: float, causal: bool):
return torch.library.wrap_triton(_attn_fwd_triton)(q, k, v, sm_scale, causal)
def setup_context(ctx, inputs, output):
q, k, v, sm_scale, causal = inputs
o, lse = output
ctx.save_for_backward(q, k, v, o, lse)
ctx.sm_scale = sm_scale
ctx.causal = causal
def backward(ctx, grad_o, grad_lse):
q, k, v, o, lse = ctx.saved_tensors
dq, dk, dv = torch.library.wrap_triton(_attn_bwd_triton)(
q, k, v, o, lse, grad_o, ctx.sm_scale, ctx.causal
)
return dq, dk, dv, None, None
flash_attn.register_autograd(backward, setup_context=setup_context)为什么用 triton_op 而不是 autograd.Function
torch.compile友好:dynamo 能正确识别为不可分解 op,反向也能被aot_autograd追踪- AOTInductor 支持:可以编译成 .so 部署到无 Python 的环境
- 可序列化:进入 fx graph 后能被 ONNX、TorchScript 等导出
13.6.3 Backward Kernel 三大设计要点
要点 1:Atomic Add 用于梯度归约
多个 program instance 写同一个 dW 位置时必须用 atomic:
tl.atomic_add(dW_ptr + offs, partial_grad, mask=mask)默认 sem="acq_rel", scope="gpu"。跨 SM 写要 scope="sys"(更慢但正确)。
H100 + Triton 3.3 的已知 atomic 问题
GH #7402 报告 tl.atomic_add 在某些跨线程模式下被报告语义不正确。Workaround:先在 SMEM 内用 tl.reduce 局部归约,再 atomic 写出,**reduce 输入量降低 32×**的同时绕开 bug。
要点 2:Recompute vs Save 的权衡
| 张量 | 保存代价(HBM) | 重算代价 | 选择 |
|---|---|---|---|
FA 的 S = QK^T | [B, H, M, N],超大 | 一次 GEMM | 重算(只存 lse) |
FA 的 P = softmax(S) | 同上 | exp + 归一 | 重算 |
LayerNorm 的 mean, var | [B, S],小 | 一次 reduce | 可选(看 register pressure) |
| MatMul 的 input | 视形状 | 不可能 | 必存 |
经典原则:只存 reduce 后的统计量(lse, mean, var),原始大张量必须重算。
要点 3:Tiling 方向的选择
FA forward 按 Q 切分 program(每个 program 处理一个 Q tile,循环遍历所有 K, V)。
FA backward 的精妙之处:
dK,dV按 K/V 切分(每个 program 处理一个 K tile,循环遍历所有 Q)—— 无需 atomicdQ按 Q 切,但和 dK/dV 切分方向冲突,需要 atomic_add 累加;FlashAttention v2/v3 选择dQ用 atomic 写回 HBM
13.6.4 FA Backward Kernel 简化示例
@triton.jit
def attn_bwd_kv_kernel(
Q, K, V, O, lse, dO, # 输入
dQ, dK, dV, # 输出
sm_scale, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, CAUSAL: tl.constexpr,
):
"""每个 program 处理一个 K/V tile([BLOCK_N, D]),算 dK[n_blk], dV[n_blk]"""
start_n = tl.program_id(0)
off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
k = tl.load(K + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
v = tl.load(V + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
# 遍历所有 Q tile
for start_m in range(0, N_CTX, BLOCK_M):
off_m = start_m + tl.arange(0, BLOCK_M)
q = tl.load(Q + off_m[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
do = tl.load(dO + off_m[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
l = tl.load(lse + off_m) # [BLOCK_M]
# 重算 S, P(FA 不存 P)
s = tl.dot(q, tl.trans(k)) * sm_scale
if CAUSAL:
s = tl.where(off_m[:, None] >= off_n[None, :], s, -float("inf"))
p = tl.exp(s - l[:, None]) # softmax
# dV += P^T @ dO
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# dP = dO @ V^T
dp = tl.dot(do, tl.trans(v))
# dS = P * (dP - rowsum(P * dP))
delta = tl.sum(p * dp, axis=1, keep_dims=True)
ds = p * (dp - delta) * sm_scale
# dK += dS^T @ Q
dk += tl.dot(tl.trans(ds.to(q.dtype)), q)
# dQ 需要 atomic 累加
dq_partial = tl.dot(ds.to(k.dtype), k)
tl.atomic_add(dQ + off_m[:, None] * D
+ tl.arange(0, BLOCK_DMODEL)[None, :],
dq_partial)
tl.store(dK + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :], dk)
tl.store(dV + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :], dv)实战经验
- 先写 forward + gradcheck,再写 backward:能 forward 跑对就成功一半
- backward 单独测:
torch.autograd.gradcheck(my_func, (q.double(), k.double(), v.double()), eps=1e-6),关闭 causal、用 FP64 - 大量 atomic 时考虑分 split-k:先把 dQ 写到
[NUM_SPLITS, M, D]中间 buffer,再用单独 reduce kernel 合并 —— 比纯 atomic 更快
13.6.5 自动微分的未来(实验性)
Triton GH Discussion #6913 有 Iaroslav Elistratov 的 PoC:在 Triton IR 层做自动微分。
- 对 FA2 forward 写 backward 可移除 300 行用户代码
- 对 LayerNorm 移除 120 行
- 还未合入主线,但路线明确
预计 Triton 3.5+ 会有 @triton.jit(autodiff=True) 装饰器选项,类似 JAX 的 jax.grad,自动生成 backward kernel。
13.7 本章小结
把六个真实案例串成一条线,你会发现工业级 Triton 的"道法术":
道(设计哲学):
- 写一份 Triton 顶替多份 CUDA,跨架构维护成本骤降(vLLM, SGLang)
- 物理 / 逻辑分离 + 间接寻址,把 OS 思想搬进 GPU 内存管理(PagedAttention)
- 数学等价的算子重排能省 3-5× 显存(Unsloth LoRA 关联律)
法(架构模式):
- 用
tl.constexpr分支统一 source、分裂 binary(vLLM phase 3) - sorted + padded → 单 kernel grouped GEMM(SGLang fused_moe)
- 两级 scaling 隔离 dequant overhead(MOSS / COAT FP8)
- recompute 中间张量 + 只存 reduce 后统计量(FA backward)
术(工程实践):
TORCH_COMPILE_DEBUG=1+output_code.py读懂 Inductor 在做什么torch.library.triton_op优先于torch.autograd.Function- atomic_add 高 reduction 时先 SMEM 局部归约
- LoRA 微调强制走
(A @ x)优先的右结合 - backward kernel 按 K/V 切(dK, dV 无需 atomic),dQ 用 atomic
关键对照表
| 系统 | 核心创新 | Triton 角色 | 可学到 |
|---|---|---|---|
| vLLM | PagedAttention | AMD 主力 / NVIDIA 备份 | 间接寻址 + grid 形态切换 |
| SGLang | RadixAttention | 量化栈 / MoE / AMD | sorted+padded grouped GEMM |
| PyTorch Inductor | fusion + autotune | pointwise/reduction 后端 | fusion 规则与调试技巧 |
| Unsloth | LoRA 算子级融合 | 全栈 | 算子重排 + 自定义 autograd |
| MOSS / COAT | 两级 scaling FP8 | 量化 kernel | dequant overhead 控制 |
| FlashAttention | recompute backward | forward + backward 全套 | 反向核函数设计 |
思考题
思考题 1:vLLM Phase 2 的 throughput 提升来源
vLLM 从单 kernel 改成 prefill/decode 双 kernel 后 throughput 提升 3.7×。请定量分析至少三个性能瓶颈的解除:grid 形态、寄存器占用、SMEM 利用率。如果给你一个 batch 包含 50% prefill (avg seq=512) + 50% decode (seq=1),单 kernel 在 H100 上 SM 占用率大约是多少?
参考思路
- 单 kernel grid 按最大 query_len 切,decode 的
query_len=1的 program 立刻 mask 退出,算力浪费 ~99% - 单 kernel BLOCK_M 必须足够大覆盖 prefill,导致 decode 时寄存器/SMEM 都被无效分配
- 双 kernel 后 decode 用紧致的 2D grid
(num_seqs, num_heads),每个 program 工作量饱满 - 估算:单 kernel decode 部分 SM 占用率仅 1-5%;双 kernel 后 60-80%
思考题 2:SGLang fused_moe 的 padding 浪费
fused_moe_kernel 通过 padding 让每个 expert 的 token 数对齐 BLOCK_M。若 BLOCK_M=128、有 8 个 expert,最坏情况下 padding 浪费多少 token?给定 total_tokens = 4096, num_experts = 8 且路由完全均匀,浪费率是多少?路由极不均匀(90% token 路由到 1 个 expert)时呢?
思考题 3:Inductor 融合断点定位
写一段代码 def f(x): return torch.softmax(x @ W, dim=-1) + bias,用 TORCH_COMPILE_DEBUG=1 抓 ir_pre_fusion.txt 与 ir_post_fusion.txt。请回答:
softmax是被融合成triton_per_*还是triton_red_*?决定因素是什么?@W走的是 Triton 还是 CUTLASS?为什么 Inductor 默认在大 GEMM 上不用 Triton?- 最终
+ bias能否融入 softmax kernel?还是单独成 kernel?
思考题 4:Unsloth LoRA 关联律的反例
(B @ A) @ x vs B @ (A @ x) 在 out=11008, in=4096, r=16, batch=8192 上对比显存与时间。何时前者反而更快?(提示:思考极小 batch、极小 in/out 维度的情形)
思考题 5:FP8 两级 scaling 的精度边界
设 fine-grain group 是 128 个元素的 E8M0 scale(仅指数),coarse scale 是 per-tensor FP32。如果 weight tensor 的全局 max 是 100,但某个 128-token group 的局部 max 是 0.01,dequant 时实际表示精度损失多少 bit?为什么 DeepSeek-V3 用 [128, 128] 而不是 [1024, 1024] 的 block?
思考题 6:Backward kernel atomic 还是 split-k
FA backward 的 dQ 写回有两种实现:
- A:
tl.atomic_add直接写最终位置 - B:先写到
[NUM_SPLITS, M, D]中间 buffer,再单独 reduce kernel 合并
设 BLOCK_M=64, BLOCK_N=128, NUM_HEADS=32, N_CTX=8192, NUM_SPLITS=N_CTX/BLOCK_N=64。请估算两种方案的:
- 中间 HBM 占用
- 跨 SM atomic 冲突频率
- 总 HBM 带宽消耗
什么情况下 A 更优?什么情况下 B?
参考思路
- B 方案中间 buffer =
64 × 8192 × 128 × 4B = 256 MB,仅 dQ 已很重;A 方案零额外 - A 的 atomic 冲突频率与
NUM_HEADS × NUM_K_BLOCKS成正比;高时性能严重下降 - 当
NUM_K_BLOCKS小(短序列)→ A 优;长序列 / 高 head 数 → B 优 - FlashAttention v2 论文选 A 是因为长序列下 SMEM 容量限制 B 的中间 buffer
下一章我们将转向工具链 —— 当你的 kernel 跑得不够快时,如何用 NSight Compute、Triton 内置 profiler、PyTorch profiler 三件套定位真正的瓶颈,并把"看上去对"的微优化与"实测有效"的优化分开。