Skip to content

13. 生产级案例深度剖析

走进 vLLM、SGLang、PyTorch Inductor、Unsloth —— 每天为亿万请求保驾的真实 Triton 代码:它们怎么写、为什么这么写、踩过哪些坑、留下哪些可复用的范式。

前 12 章我们学了 Triton 的语法、内存模型、编译器原理与 Hopper/Blackwell 新硬件。但真正能跑在生产环境的 kernel,远不止"算得对、调得快":它要面对动态 batch、变长序列、跨架构兼容、FP8/INT8 量化的数值稳定性,以及和 PyTorch autograd / torch.compile 的无缝拼装。本章把六类最具代表性的工业实现拆开摆在桌上,逐行解释设计动机与权衡,让你看完能直接把同一套模式抄进自己的项目。

本章内容概览

  • 13.1 vLLM PagedAttention 的 Triton 化之路
  • 13.2 SGLang RadixAttention 与 Triton 量化栈
  • 13.3 PyTorch Inductor 的 Triton 代码生成
  • 13.4 Unsloth / Axolotl:LoRA 微调的核函数艺术
  • 13.5 FP8 / INT8 量化核函数编写模式
  • 13.6 反向传播核函数实战
  • 13.7 本章小结与思考题

13.1 vLLM PagedAttention 的 Triton 化之路

13.1.1 为什么 PagedAttention 改变了 LLM 推理

在 2023 年之前,LLM 推理引擎普遍把每个请求的 KV cache 分配成一段连续显存。问题立刻浮现:

  • 不同请求长度差异巨大(从 128 token 到 8K token),导致外部碎片
  • 预分配最大长度时,内部碎片也极严重(用户实际只用 1/4 序列)
  • 实测 SOTA 引擎只能利用 20-40% 的 KV cache 显存(vLLM SOSP'23 论文数据)

vLLM 借鉴操作系统的虚拟内存分页思路提出 PagedAttention:

概念操作系统PagedAttention
页(Page)4 KB 物理页BLOCK_SIZE token 的 KV(默认 16)
进程用户进程单个 request
页表Page TableBlock Table(逻辑块号 → 物理块号)
缺页Page Fault"新 block 申请"(实际是 free list pop)

KV cache 不再连续 —— 每 16 个 token 一个 block,物理上可以散落在显存任意位置;request 持有 block_table 指明顺序。外部碎片清零,内部碎片缩到最后一个 block。

论文:Efficient Memory Management for Large Language Model Serving with PagedAttention (SOSP'23)

13.1.2 从 CUDA 到 Triton:vLLM V1 的演化

vLLM 0.6 之前的 attention kernel 是 csrc/attention/attention_kernels.cu 的纯 CUDA 实现。两个痛点推动了 Triton 化:

  1. 跨架构维护成本:AMD MI300、Intel Gaudi、Apple M 系列都要单独写一份
  2. 快速迭代的研究改动:Sliding window、ALiBi、chunked prefill 每改一次都要重新写 CUDA + 编译 wheel

vLLM V1(2025 年起)把 Triton attention backend 作为 AMD 与通用硬件的默认实现,文件在 vllm/v1/attention/backends/triton_attn.py

来源:vLLM Triton Backend Deep DiveEnabling vLLM V1 on AMD GPUs With Triton

13.1.3 三阶段演进:单 kernel → 拆分 → 重新合并

Phase 1:单 kernel 同时服务 prefill + decode

最直接的写法:把 grid 设为 (tokens_in_batch, num_query_heads, query_seq_block),每个 program 处理 [BLOCK_M, BLOCK_DMODEL] 的 Q 块对所有 KV blocks 的 attention。

python
@triton.jit
def unified_attention_kernel(
    Q_ptr, K_cache_ptr, V_cache_ptr, Out_ptr,
    block_tables_ptr,           # [num_seqs, max_blocks_per_seq]
    seq_lens_ptr,               # [num_seqs]
    num_heads, num_kv_heads, head_size,
    BLOCK_SIZE: tl.constexpr,   # KV cache 物理块大小(如 16)
    BLOCK_M: tl.constexpr,      # query token 维度的 tile
    BLOCK_DMODEL: tl.constexpr,
):
    seq_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    q_blk_idx = tl.program_id(2)

    seq_len = tl.load(seq_lens_ptr + seq_idx)
    block_table_offset = seq_idx * MAX_BLOCKS_PER_SEQ
    # 逐 physical block 取 K/V,做 online softmax
    ...

问题:decode 阶段 query_len = 1,但 grid 第 3 维仍按 BLOCK_M 切分;多数 program 立刻 mask 退出,SM 占用率掉到 30%

Phase 2:拆成两个 kernel

  • prefix_prefill_kernelquery_len > 1 时使用,含 chunked prefill
  • paged_attention_2d_kernelquery_len == 1 时使用,grid 改成 (num_seqs, num_query_heads),每个 program 处理 1 个 query token 对整段历史 KV 的 attention,BLOCK_N = 16 对齐 vLLM 默认 page size
python
@triton.jit
def paged_attention_2d_kernel(
    Q_ptr, K_cache_ptr, V_cache_ptr, Out_ptr,
    block_tables_ptr, seq_lens_ptr,
    scale,
    num_query_heads, num_kv_heads, head_size,
    BLOCK_SIZE: tl.constexpr,        # 16
    BLOCK_DMODEL: tl.constexpr,
):
    seq_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    seq_len = tl.load(seq_lens_ptr + seq_idx)
    num_blocks = tl.cdiv(seq_len, BLOCK_SIZE)

    # 单个 query token:装进寄存器
    q = tl.load(Q_ptr + seq_idx * num_query_heads * head_size
                       + head_idx * head_size + tl.arange(0, BLOCK_DMODEL))
    q = q * scale

    # online softmax 累加器
    m_i = -float("inf")
    l_i = 0.0
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    kv_head_idx = head_idx // (num_query_heads // num_kv_heads)  # GQA

    for blk_idx in range(num_blocks):
        # 查 block_table 拿到物理块号
        physical_blk = tl.load(block_tables_ptr
                               + seq_idx * MAX_BLOCKS_PER_SEQ + blk_idx)
        # 装入这个 block 的 K (BLOCK_SIZE × head_size)
        k_offs = (physical_blk * num_kv_heads * BLOCK_SIZE * head_size
                  + kv_head_idx * BLOCK_SIZE * head_size
                  + tl.arange(0, BLOCK_SIZE)[:, None] * head_size
                  + tl.arange(0, BLOCK_DMODEL)[None, :])
        k = tl.load(K_cache_ptr + k_offs)

        # 计算 q·K^T(注意是单 query token 对 BLOCK_SIZE 个 key)
        s = tl.sum(q[None, :] * k, axis=1)
        # mask 超出 seq_len 的位置
        tokens_in_blk = tl.arange(0, BLOCK_SIZE) + blk_idx * BLOCK_SIZE
        s = tl.where(tokens_in_blk < seq_len, s, -float("inf"))

        # online softmax 更新
        m_new = tl.maximum(m_i, tl.max(s, axis=0))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(s - m_new)

        # 装入 V,更新 acc
        v = tl.load(V_cache_ptr + k_offs)
        acc = acc * alpha + tl.sum(p[:, None] * v, axis=0)
        l_i = l_i * alpha + tl.sum(p, axis=0)
        m_i = m_new

    acc = acc / l_i
    tl.store(Out_ptr + seq_idx * num_query_heads * head_size
                     + head_idx * head_size + tl.arange(0, BLOCK_DMODEL),
             acc.to(tl.float16))

实测:拆分后 throughput +3.7×(来源同上)。

Phase 3:又合回单 kernel(unified kernel)

维护两份代码代价巨大 —— bug 修复要同步、autotune 配置要双份、行为差异容易漏。vLLM 当前主线正在把两条路再次统一回单 kernel,关键技巧:用 tl.constexprIS_DECODE 分支让编译器在编译期消除 dead code,runtime 仍是两套机器码源码只一份

13.1.4 Prefix Caching 的实现细节

vLLM V1 默认开启 Automatic Prefix Caching (APC)

text
请求 1: "You are a helpful assistant. Translate to French: hello"
请求 2: "You are a helpful assistant. Summarize this text: ..."
       └───── 共享前缀 ─────┘

实现路径:

  1. 块级 hash:对每个 KV block,把 token IDs 与上一块的 hash 做链式哈希(hash(prev_hash, tokens_in_this_block)
  2. 命中查询:新请求 prefill 时按块查 hash 表,命中则直接复用物理块 + bump refcount
  3. 驱逐:LRU on blocks,未引用且最旧的优先

代码:vllm/v1/core/kv_cache_utils.pyBlockHashTypehash_block_tokens

chunked prefill 的已知限制

当 APC 与 chunked prefill 同时启用时,只有 prompt 的第一个 chunk 享受 prefix cacheGH #7883)。原因:后续 chunk 的 hash 计算依赖前一 chunk 已经写入 KV cache 的状态,而该状态被 chunked 切散,链式 hash 容易错位。修复方案在讨论中。

13.1.5 取走的工程模式

text
1. 物理 KV 与逻辑 KV 分离 → block_table 间接寻址
2. 写一份 Triton 顶替多份 CUDA → 跨架构维护成本骤降
3. 用 grid 维度形态匹配 workload 形态(prefill vs decode)→ 占用率
4. tl.constexpr 分支统一 source、分裂 binary → 既减代码又不损性能
5. 块级 hash 实现 token 级缓存复用 → 无锁、O(1) 查询

13.2 SGLang RadixAttention 与 Triton 量化栈

13.2.1 RadixAttention:从分页到字典树

SGLang 提出的 RadixAttention 与 vLLM PagedAttention 解决的问题略有不同:

维度vLLM (APC)SGLang (RadixAttention)
数据结构块 hash 表Radix Tree(压缩前缀树)
粒度Block-level(16 token 块)Token-level(不依赖块对齐)
共享发现仅命中链式块 hash自动找最长公共前缀
Fork 支持原生(agent A/B 分支共享根)
驱逐LRU on blocksLRU on leaf nodes
最佳场景静态 system prompt多轮对话、Tree-of-Thought、Agent

举例:

text
请求 A: [System prompt 100 tokens] + "翻译 hello"
请求 B: [System prompt 100 tokens] + "解释 hello"
请求 C: [System prompt 100 tokens] + "翻译 world"

Radix Tree:
  root ──[System prompt 100]── ┬── "翻译 " ──┬── "hello"  (req A)
                                │              └── "world"  (req C)
                                └── "解释 hello"            (req B)

任意公共前缀(即使不是 16 倍数)都能复用,且新 fork 的请求几乎零启动开销

源码:sglang/srt/mem_cache/radix_cache.py,变种 SWARadixCache(滑动窗口)、ChunkCacheMambaRadixCache

论文:SGLang: Efficient Execution of Structured Language Model Programs (NeurIPS 2024)LMSYS Blog

13.2.2 SGLang 的 Triton kernel 分布

SGLang 在 NVIDIA 上默认走 FlashInfer + FlashAttention(CUDA 实现),但有大量 Triton kernel 在以下场景使用:

路径后端备注
NVIDIA attention 主力FlashInferC++/CUDA
AMD attentionTriton共享 vLLM 同款 backend
FP8 GEMM 默认CUTLASSC++
FP8 GEMM Triton 可选Triton--fp8-gemm-backend triton
W8A8 量化 / dequantTritonw8a8_fp8, fp8_block_quant
MoE expert routingTritonfused_moe_kernel
Sampling kernelTritontop-k/top-p with temperature

来源:SGLang quantization 文档

13.2.3 fused_moe_kernel 模式解析

MoE 推理的核心是 grouped GEMM:把每个 token 路由到 top-k 个 expert,对应不同的 (M_e, K) × (K, N) 矩阵相乘。SGLang 的 Triton 实现关键设计:

python
@triton.autotune(configs=MOE_CONFIGS, key=["N", "K", "top_k"])
@triton.jit
def fused_moe_kernel(
    A_ptr, B_ptr, C_ptr,
    topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N, K, EM, num_valid_tokens,
    stride_am, stride_ak, stride_be, stride_bk, stride_bn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # super-grouping 提升 L2 局部性
    pid_m, pid_n = swizzle(pid, num_pid_m, num_pid_n, GROUP_SIZE_M)

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return  # padding 区域,全跳过

    # 关键 trick 1:sorted_token_ids 已按 expert 排序,连续 BLOCK_SIZE_M 个 token 同属一 expert
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    # 关键 trick 2:用 expert_ids 找当前 tile 对应的 B 矩阵
    off_experts = tl.load(expert_ids_ptr + pid_m)

    # 之后是标准 GEMM 主循环 + epilogue 乘 routed weight
    ...

要点:

  • 排序+padding 把 grouped GEMM 化简为单 kernel:所有同 expert 的 token 连续摆放,每个 BLOCK_M tile 自然命中同一 expert
  • num_tokens_post_padded 记录 padding 后的总 token 数,让 grid 一次性覆盖所有 expert
  • MUL_ROUTED_WEIGHT: tl.constexpr 编译期分支,下行/上行 FFN 用不同二进制(下行要乘 router gate,上行不乘)

这套模式被 Unsloth、Megatron、PyTorch grouped_mm 反复复用,是 MoE Triton kernel 的事实标准。

13.2.4 FP8 block-scaled GEMM

DeepSeek-V3 流行的 block-scaled FP8 在 SGLang 的 Triton 实现:

  • Activation 按 [1, 128] 分组、weight 按 [128, 128] 分组,每组一个 FP32 scale
  • 主 GEMM 在 Tensor Core 上做 FP8 × FP8 → FP32
  • scale 在 K 维内层乘进 accumulator:每 BLOCK_K = 128 步乘一次对应的 (a_scale, b_scale)
python
for k in range(0, tl.cdiv(K, BLOCK_K)):
    a = tl.load(a_fp8_ptr + ...)            # FP8
    b = tl.load(b_fp8_ptr + ...)            # FP8
    a_s = tl.load(a_scale_ptr + k)          # FP32, [BLOCK_M, 1]
    b_s = tl.load(b_scale_ptr + k * stride) # FP32, [1, BLOCK_N]
    acc += tl.dot(a, b, out_dtype=tl.float32) * (a_s * b_s)

Dequant overhead 陷阱

* (a_s * b_s)CUDA Core(FP32)而非 Tensor Core 上执行。H100 FP32 CUDA Core 仅有 FP8 Tensor Core 的 1.6% 算力,单次 dequant 等价于 ~60 个 FP8 MAC 的代价。深思的工程师会让 BLOCK_K 取较大值(≥128),把 dequant 频率压到最低;极端场景把所有 scale 收尾时一次性应用(两级 scaling),但这要求权重 fine-grain 量化误差受控。详见 13.5 节。


13.3 PyTorch Inductor 的 Triton 代码生成

13.3.1 torch.compile 的全栈视角

text
Python source

   ▼  TorchDynamo(字节码级符号执行)
FX Graph (高层算子)

   ▼  AOTAutograd(前向 + 反向联合捕获)
Joint Forward+Backward FX Graph

   ▼  Inductor lowering(torch/_inductor/lowering.py)
Inductor IR (pointwise / reduction / GEMM 三类节点)

   ▼  Inductor fusion pass
Fused Inductor IR

   ├─→ GEMM/Linear        → CUTLASS / cuBLAS / cuBLASLt
   ├─→ Pointwise/Reduction → ★ Triton codegen ★
   └─→ Unsupported          → ATen / 直接调 CUDA

入口:torch/_inductor/compile_fx.py,Triton codegen 在 torch/_inductor/codegen/triton.py

详细设计文档:TorchInductor: a PyTorch-native Compiler

13.3.2 一个最小示例:fused add + relu

python
import torch
def f(x, y):
    return torch.relu(x + y)

compiled = torch.compile(f)
x = torch.randn(1024, 1024, device="cuda")
y = torch.randn(1024, 1024, device="cuda")
compiled(x, y)

设置 TORCH_LOGS="output_code" 后 Inductor 生成(简化):

python
@triton.jit
def triton_poi_fused_add_relu_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
                                 XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel

    tmp0 = tl.load(in_ptr0 + xindex, xmask)
    tmp1 = tl.load(in_ptr1 + xindex, xmask)
    tmp2 = tmp0 + tmp1
    tmp3 = tl.full([1], 0, tl.float32)
    tmp4 = triton.language.maximum(tmp2, tmp3)
    tl.store(out_ptr0 + xindex, tmp4, xmask)

命名约定揭示了类型:

前缀含义典型 op
triton_poi_*pointwise kerneladd, mul, relu, sigmoid
triton_red_*reduction kernelsum, mean, max
triton_per_*persistent reduction单 tile 装得下整行的 softmax
triton_tem_*template kernelmatmul template-based

13.3.3 调试 Inductor 输出的三种武器

武器 1:TORCH_LOGS

bash
TORCH_LOGS="output_code" python my_model.py
# 等价 Python:
torch._logging.set_logs(output_code=True)

直接打到 stderr,看每个 fused kernel 的 Triton 源码。

武器 2:TORCH_COMPILE_DEBUG

bash
TORCH_COMPILE_DEBUG=1 python my_model.py

在执行目录创建 torch_compile_debug/run_*/ 子目录,含:

text
torch_compile_debug/run_2026_05_28_10_30_00/
├── torchinductor/
│   └── model__0_inference_0/
│       ├── fx_graph_readable.py      # 人类可读的 FX graph
│       ├── fx_graph_runnable.py      # 可独立运行的复现脚本
│       ├── fx_graph_transformed.py   # 经过 lowering 后
│       ├── ir_pre_fusion.txt         # Inductor IR(融合前)
│       ├── ir_post_fusion.txt        # 融合后
│       └── output_code.py            # ★ 最终 Triton kernel
└── dynamo/
    └── ...

ir_pre_fusion.txt vs ir_post_fusion.txt 对比能直观看到哪些 op 被融合掉了,进而判断是否触发预期融合。

武器 3:TORCHINDUCTOR_CACHE_DIR

bash
TORCHINDUCTOR_CACHE_DIR=/tmp/my_cache python my_model.py

控制编译缓存位置,方便清理或在多个项目间隔离。

教程:TORCH_LOGS recipe

13.3.4 Inductor 融合策略关键规则

torch/_inductor/scheduler.py 后能总结出几条Inductor 决定是否融合两个节点的核心规则:

规则例子
必须共享至少一个 read,或一个的输出是另一个的输入add(x,y) + relu(out) ✅;add(x,y) + mul(a,b)
形状必须兼容广播[B,N] + [B,N] ✅;[B,N] + [N,B] ❌(需 transpose)
不能跨越Reduction 边界(除 persistent reduction)sumrelu 通常各为一个 kernel
buffer 复用约束:被多个下游消费的中间 buffer 必须物化写出 extern_kernels.copy_ 而非融合

让融合更激进的实用技巧

  1. .contiguous() 放到管道最早位置,避免融合被 stride mismatch 打断
  2. 避免 in-place optorch.compile 函数体内(如 x.add_(y)),会阻断融合
  3. 对 reduction-heavy 模型,设 torch._inductor.config.triton.unique_kernel_names = True 后看命名能否合并

13.3.5 Inductor 自动调优

torch/_inductor/runtime/triton_heuristics.py 中的 triton_config_with_settings() 是 Inductor 的 mini-autotune:

python
# 简化版伪代码
def triton_config(size_hints, x, y=None, z=None, num_stages=1):
    cfg = {"XBLOCK": x}
    if y is not None: cfg["YBLOCK"] = y
    if z is not None: cfg["ZBLOCK"] = z
    return triton.Config(
        cfg,
        num_warps=num_warps_heuristic(size_hints),
        num_stages=num_stages,
    )

它会枚举 XBLOCK ∈ {64, 128, 256, 512, 1024, 2048}num_warps ∈ {2, 4, 8},跑一次 benchmark 选最优。结果缓存到 ~/.cache/torch_inductor/,避免每次重跑。

不要在长 benchmark 中信第一次结果

Inductor 第一次 compile 包含 autotune 时间,远慢于稳态。torch.compile(..., mode="max-autotune") 后必须 warm up 至少 3 次才能拿到代表性数字。


13.4 Unsloth / Axolotl:LoRA 微调的核函数艺术

13.4.1 LoRA 微调的瓶颈在哪里

LoRA 把权重更新拆为 W' = W + α·B·A(A, B 是低秩矩阵)。朴素 PyTorch 实现:

python
def lora_forward(x, W, A, B, scale):
    return x @ W.T + scale * (x @ A.T @ B.T)

PyTorch autograd 默认会保存所有中间张量用于 backward:

  • x @ A.T:shape [B, r],r 通常 16-64,开销小
  • x @ W.T:shape [B, out],开销大但必存
  • forward activation 必存到 backward 调用

实际显存占用经常被中间 activation 拉爆。Unsloth 的解法:自定义 torch.autograd.Function,手工选择哪些张量存、哪些重算

13.4.2 Unsloth 的五大核心优化

优化 1:Fused activation kernel(SwiGLU / GEGLU)

LLaMA 风格的 MLP 是 down(silu(gate(x)) * up(x))。Unsloth 把 silu * up 融合进单 Triton kernel,省掉中间张量:

python
@triton.jit
def fused_swiglu_kernel(gate_ptr, up_ptr, out_ptr, n_elements,
                        BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    g = tl.load(gate_ptr + offs, mask=mask)
    u = tl.load(up_ptr + offs, mask=mask)
    # silu(g) * u = g * sigmoid(g) * u
    silu_g = g * tl.sigmoid(g)
    tl.store(out_ptr + offs, silu_g * u, mask=mask)

backward 同样融合:

python
@triton.jit
def fused_swiglu_backward_kernel(grad_out_ptr, gate_ptr, up_ptr,
                                  grad_gate_ptr, grad_up_ptr, n_elements,
                                  BLOCK: tl.constexpr):
    ...
    sig_g = tl.sigmoid(g)
    silu_g = g * sig_g
    # d(silu_g * u)/du = silu_g
    grad_u = grad_out * silu_g
    # d(silu_g * u)/dg = u * (sig_g + g * sig_g * (1 - sig_g))
    grad_g = grad_out * u * (sig_g * (1.0 + g * (1.0 - sig_g)))
    tl.store(grad_gate_ptr + offs, grad_g, mask=mask)
    tl.store(grad_up_ptr + offs, grad_u, mask=mask)

关键设计:backward 只存 g 和 u,不存 silu(g) * u

重新算 sigmoid(g) 比存一个 [B*S, hidden] 的中间张量便宜得多。Recompute 显存换算力在 SwiGLU 这种 fused activation 上几乎免费。

优化 2:LoRA MLP 算子级融合

Unsloth 写了 LoRA_MLP.forward() —— base + LoRA 的 forward 与 backward 共享中间张量

python
class FastLoraMLPForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, W_gate, W_up, W_down,
                gate_A, gate_B, up_A, up_B, down_A, down_B, scale):
        # 一次性算完 base
        gate_out = x @ W_gate.T + scale * (x @ gate_A.T @ gate_B.T)
        up_out   = x @ W_up.T   + scale * (x @ up_A.T   @ up_B.T)
        # SwiGLU 融合
        act = fused_swiglu(gate_out, up_out)
        out = act @ W_down.T + scale * (act @ down_A.T @ down_B.T)

        # save_for_backward 只存最少必要张量
        ctx.save_for_backward(x, gate_out, up_out, act,
                              W_gate, W_up, W_down,
                              gate_A, gate_B, up_A, up_B, down_A, down_B)
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # 一次 backward 把所有 gradient 算完,避免 PyTorch autograd 的"账本"开销
        ...

PyTorch 默认 autograd 会为每个 @ 操作都记下 op + saved tensors,bookkeeping 本身就消耗 10-15% 时间。手写 Function 直接绕过。

优化 3:关联律重排

数学上 (B @ A) @ x ≡ B @ (A @ x),但显存差距巨大:

  • B @ A[out, r] @ [r, in] = [out, in]和原 weight 一样大
  • A @ x[r, in] @ [in, batch] = [r, batch],r 通常 ≤64

Unsloth 强制走右结合:先 A @ x[r, batch],再 B @ (A @ x)[out, batch]显存省 3-5×

优化 4:MoE grouped_mm

UNSLOTH_MOE_BACKEND=grouped_mm 后启用专用 Triton kernel,借鉴 13.2.3 的 sorted+padded 模式。Unsloth 声称在 gpt-oss / Qwen3-MoE / DeepSeek R1/V3 / GLM 上提供 12× 训练速度、35% VRAM 节省

来源:Unsloth MoE blog

优化 5:4-bit / 8-bit 原生集成

LoRA base weight 用 NF4 / Int8 量化存储,Triton kernel 在内层主循环里在线 dequant:

python
# 简化示意
for k_blk in range(K_blocks):
    w_q = tl.load(W_quant + ...)       # int4 / uint8
    w_scale = tl.load(W_scale + ...)   # FP16, 每 block 一个
    w = w_q.to(tl.float16) * w_scale   # dequant
    acc += tl.dot(x, w, acc)

13.4.3 Axolotl:开源的 Unsloth-like 路径

Unsloth 只把 LoRA 路径开源(issue #1038),全参数 fine-tune 优化是商业版。Axolotl 基于 Unsloth 思路独立实现:

  • axolotl/integrations/lora_kernels/:fused SwiGLU/GEGLU forward+backward
  • 自定义 torch.autograd.Function 处理 LoRA MLP/Attention 算子融合
  • 同时支持 DDP / DeepSpeed / FSDP2

实测在 H100 / B200 上 LoRA 微调 2-5× 加速、80% 显存节省

来源:Axolotl LoRA OptimizationsAxolotl 工程博客

13.4.4 复制到自己项目的实战清单

text
1. 任何 element-wise activation + 下一步 elementwise,立即写 fused kernel
2. LoRA 路径必走 (A @ x) 优先的关联律重排
3. 自定义 autograd Function 时严格 save_for_backward 最小集
4. 中间张量能 recompute 的尽量 recompute(H100 SMEM 带宽爆炸 vs HBM 紧张)
5. 量化权重在内层主循环 dequant,不要预先物化整个 dequant 矩阵

13.5 FP8 / INT8 量化核函数编写模式

13.5.1 三种 scaling 粒度

粒度scale 数量精度损失dequant 开销代表实现
per-tensor1极低(kernel 尾)TransformerEngine 早期
per-token (activation) / per-channel (weight)M / NvLLM w8a8_fp8
per-block (e.g. 128×128)M/128 × K/128DeepSeek-V3, FP8 GEMM

13.5.2 PyTorch 官方的 GridQuant + GEMM 双 kernel 路径

PyTorch blog 给出的 H100 FP8 GEMM 经典 pipeline:

来源:PyTorch — Accelerating GEMMs with Triton

Phase 1:2D block quantization (256×256)

python
@triton.jit
def grid_quant_fp8_kernel(A_ptr, A_fp8_ptr, A_scale_ptr,
                          M, N, FP8_MAX: tl.constexpr,
                          BLOCK_M: tl.constexpr,    # 256
                          BLOCK_N: tl.constexpr,    # 256
                          SUB_TILE: tl.constexpr):  # 32
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # grid-stride 求 abs_max
    abs_max = 0.0
    for sub_m in range(0, BLOCK_M, SUB_TILE):
        for sub_n in range(0, BLOCK_N, SUB_TILE):
            offs_m = pid_m * BLOCK_M + sub_m + tl.arange(0, SUB_TILE)
            offs_n = pid_n * BLOCK_N + sub_n + tl.arange(0, SUB_TILE)
            a = tl.load(A_ptr + offs_m[:, None] * N + offs_n[None, :])
            abs_max = tl.maximum(abs_max, tl.max(tl.abs(a)))

    scale = abs_max / FP8_MAX
    tl.store(A_scale_ptr + pid_m * tl.cdiv(N, BLOCK_N) + pid_n, scale)

    # 第二遍:写出 FP8
    inv_scale = 1.0 / scale
    for sub_m in range(0, BLOCK_M, SUB_TILE):
        for sub_n in range(0, BLOCK_N, SUB_TILE):
            ...
            a_fp8 = (a * inv_scale).to(tl.float8e4nv)
            tl.store(A_fp8_ptr + ..., a_fp8)

Phase 2:FP8 Tensor Core GEMM + epilogue descale

python
@triton.jit
def fp8_gemm_kernel(A_fp8_ptr, B_fp8_ptr, C_ptr,
                    A_scale_ptr, B_scale_ptr,
                    M, N, K, ...):
    ...
    acc = tl.zeros([BLOCK_M, BLOCK_N], tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A_fp8_ptr + ...)
        b = tl.load(B_fp8_ptr + ...)
        acc += tl.dot(a, b)                 # FP8 × FP8 → FP32

    # 在 epilogue 一次性 dequant:每 256×256 块一个 scale
    a_s = tl.load(A_scale_ptr + ...)        # 标量
    b_s = tl.load(B_scale_ptr + ...)        # 标量
    acc = acc * a_s * b_s
    tl.store(C_ptr + ..., acc.to(tl.bfloat16))

为什么 256×256 而不是 128×128

  • 块越大 → scale 数量越少 → epilogue dequant 开销越低
  • 块越大 → 异常 outlier 越容易拉爆 scale → 精度损失越高
  • 256×256 是 H100 FP8 在 LLM 场景上的甜点;DeepSeek-V3 选 128×128 是因为推理用,精度更敏感

13.5.3 vLLM 的 per-token-group dynamic FP8

vllm/model_executor/layers/quantization/utils/fp8_utils.py

python
@triton.jit
def _per_token_group_quant_fp8(
    y_ptr, y_q_ptr, y_s_ptr,
    group_size, y_num_columns, y_row_stride,
    eps, fp8_min, fp8_max,
    use_ue8m0: tl.constexpr,  # E8M0 power-of-2 scale 用于 microscaling
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)              # 每个 program 处理一个 group
    row = pid // num_groups_per_row
    col_group = pid % num_groups_per_row

    offs = (row * y_row_stride
            + col_group * group_size
            + tl.arange(0, BLOCK))
    mask = tl.arange(0, BLOCK) < group_size

    y = tl.load(y_ptr + offs, mask=mask)
    abs_max = tl.max(tl.abs(y))
    scale = tl.maximum(abs_max / fp8_max, eps)

    if use_ue8m0:
        # MXFP8 风格:scale 量化为 power-of-2(E8M0)
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    y_q = (y / scale).to(tl.float8e4nv)
    tl.store(y_q_ptr + offs, y_q, mask=mask)
    tl.store(y_s_ptr + pid, scale)

use_ue8m0=True 启用 MXFP8 微缩放 —— scale 本身用 E8M0(仅指数)表示,硬件可在 Tensor Core 内自动应用,消除 CUDA Core dequant 开销。这是 Blackwell 的关键新能力。

13.5.4 Dequant Overhead 与两级 scaling

复述 13.2.4 的核心:H100 FP32 CUDA Core 仅 FP8 Tensor Core 1.6% 算力,单次 dequant 约 60 个 FP8 MAC 代价。在 K 维 64 次循环中每次 dequant 一遍,会完全抹平 FP8 vs BF16 的增益

两级 scaling(MOSS、COAT 等论文方案):

text
fine-grain scale:每 group 一个 E8M0(power-of-2)
coarse scale :  每 tensor 一个 FP32

主循环:acc += tl.dot(a_fp8, b_fp8)  # 不做 dequant
        (E8M0 group scale 由 Tensor Core 自动应用,零额外指令)
epilogue:acc * a_tensor_scale * b_tensor_scale  # 一次性

论文:MOSS (arXiv 2511.05811)COAT (arXiv 2410.19313)

13.5.5 INT8 与 FP8 的关键差异

维度INT8FP8 (E4M3/E5M2)
数值表示对称 / 非对称 整数浮点(有 exponent)
动态范围256 个等距值E4M3: ±448,更宽
异常值容忍差(需 SmoothQuant 等手段)较好(exponent 覆盖)
Quant 公式round(x / scale) + zero_point(x / scale).to(fp8)
Tensor CoreA100 / H100 / B200 都支持H100+
典型应用CNN、传统 NLPLLM 推理

INT8 Triton kernel 多一层 zero_point 处理 + asymmetric 校正:

python
q = ((x - x_min) / scale).to(tl.int8)
deq = q.to(tl.float32) * scale + x_min

Speechmatics 的 INT8 落地经验:Fast and Accurate GPU Quantization for Transformers

13.5.6 vLLM 的 online dynamic FP8 副作用

--quantization fp8 默认开启 activation 动态量化:每次 forward 实时算 abs_max,额外一次 reduce + scale 计算开销在小 batch 时占比可达 8-15%,长输入/高吞吐场景才回本。

何时选择哪种量化路径

  • 小 batch 低延迟(chatbot):静态量化(offline calibration),避免 activation 量化的 launch 开销
  • 大 batch 高吞吐(API serving):dynamic FP8 + 大 BLOCK_K,dequant 摊薄
  • 训练:MXFP8 / FP4(要求 Blackwell),fine-grain scale 抗 outlier

13.6 反向传播核函数实战

13.6.1 经典模板:torch.autograd.Function + Triton

python
import torch
import triton
import triton.language as tl

# 假设 _attn_fwd_triton 和 _attn_bwd_triton 是写好的 Triton kernel
class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, sm_scale, causal):
        o, lse = _attn_fwd_triton(q, k, v, sm_scale, causal)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.sm_scale = sm_scale
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = _attn_bwd_triton(q, k, v, o, lse, do,
                                       ctx.sm_scale, ctx.causal)
        # 返回元组数量 = forward 入参数;不要梯度的位置返回 None
        return dq, dk, dv, None, None

flash_attn = FlashAttnFunc.apply

要点:

  1. forward 内部调 Triton kernel,PyTorch autograd 不会自动追踪这些计算
  2. save_for_backward 仅保存 backward 必需张量(FA 只存 lse,重算 P)
  3. 返回元组长度与 forward 入参数严格相等
  4. torch.autograd.gradcheck 在 FP64 小规模上做数值校验

13.6.2 推荐路径:torch.library.triton_op + register_autograd

PyTorch 2.4+ 起的新 API,完全兼容 torch.compile 和 AOTInductor,避免 autograd.Function 与 dynamo 的若干 composability 陷阱:

python
import torch
import triton

@torch.library.triton_op("mylib::flash_attn", mutates_args=())
def flash_attn(q, k, v, sm_scale: float, causal: bool):
    return torch.library.wrap_triton(_attn_fwd_triton)(q, k, v, sm_scale, causal)

def setup_context(ctx, inputs, output):
    q, k, v, sm_scale, causal = inputs
    o, lse = output
    ctx.save_for_backward(q, k, v, o, lse)
    ctx.sm_scale = sm_scale
    ctx.causal = causal

def backward(ctx, grad_o, grad_lse):
    q, k, v, o, lse = ctx.saved_tensors
    dq, dk, dv = torch.library.wrap_triton(_attn_bwd_triton)(
        q, k, v, o, lse, grad_o, ctx.sm_scale, ctx.causal
    )
    return dq, dk, dv, None, None

flash_attn.register_autograd(backward, setup_context=setup_context)

为什么用 triton_op 而不是 autograd.Function

  • torch.compile 友好:dynamo 能正确识别为不可分解 op,反向也能被 aot_autograd 追踪
  • AOTInductor 支持:可以编译成 .so 部署到无 Python 的环境
  • 可序列化:进入 fx graph 后能被 ONNX、TorchScript 等导出

教程:PyTorch — User Defined Triton Kernel with torch.compile

13.6.3 Backward Kernel 三大设计要点

要点 1:Atomic Add 用于梯度归约

多个 program instance 写同一个 dW 位置时必须用 atomic:

python
tl.atomic_add(dW_ptr + offs, partial_grad, mask=mask)

默认 sem="acq_rel", scope="gpu"跨 SM 写scope="sys"(更慢但正确)。

H100 + Triton 3.3 的已知 atomic 问题

GH #7402 报告 tl.atomic_add 在某些跨线程模式下被报告语义不正确。Workaround:先在 SMEM 内用 tl.reduce 局部归约,再 atomic 写出,**reduce 输入量降低 32×**的同时绕开 bug。

要点 2:Recompute vs Save 的权衡

张量保存代价(HBM)重算代价选择
FA 的 S = QK^T[B, H, M, N],超大一次 GEMM重算(只存 lse)
FA 的 P = softmax(S)同上exp + 归一重算
LayerNorm 的 mean, var[B, S],小一次 reduce可选(看 register pressure)
MatMul 的 input视形状不可能必存

经典原则:只存 reduce 后的统计量(lse, mean, var),原始大张量必须重算

要点 3:Tiling 方向的选择

FA forward 按 Q 切分 program(每个 program 处理一个 Q tile,循环遍历所有 K, V)。

FA backward 的精妙之处:

  • dK, dVK/V 切分(每个 program 处理一个 K tile,循环遍历所有 Q)—— 无需 atomic
  • dQ 按 Q 切,但和 dK/dV 切分方向冲突,需要 atomic_add 累加;FlashAttention v2/v3 选择 dQ 用 atomic 写回 HBM

13.6.4 FA Backward Kernel 简化示例

python
@triton.jit
def attn_bwd_kv_kernel(
    Q, K, V, O, lse, dO,                    # 输入
    dQ, dK, dV,                             # 输出
    sm_scale, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr, CAUSAL: tl.constexpr,
):
    """每个 program 处理一个 K/V tile([BLOCK_N, D]),算 dK[n_blk], dV[n_blk]"""
    start_n = tl.program_id(0)
    off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

    k = tl.load(K + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
    v = tl.load(V + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])

    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

    # 遍历所有 Q tile
    for start_m in range(0, N_CTX, BLOCK_M):
        off_m = start_m + tl.arange(0, BLOCK_M)
        q = tl.load(Q + off_m[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
        do = tl.load(dO + off_m[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :])
        l = tl.load(lse + off_m)  # [BLOCK_M]

        # 重算 S, P(FA 不存 P)
        s = tl.dot(q, tl.trans(k)) * sm_scale
        if CAUSAL:
            s = tl.where(off_m[:, None] >= off_n[None, :], s, -float("inf"))
        p = tl.exp(s - l[:, None])  # softmax

        # dV += P^T @ dO
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)

        # dP = dO @ V^T
        dp = tl.dot(do, tl.trans(v))
        # dS = P * (dP - rowsum(P * dP))
        delta = tl.sum(p * dp, axis=1, keep_dims=True)
        ds = p * (dp - delta) * sm_scale

        # dK += dS^T @ Q
        dk += tl.dot(tl.trans(ds.to(q.dtype)), q)

        # dQ 需要 atomic 累加
        dq_partial = tl.dot(ds.to(k.dtype), k)
        tl.atomic_add(dQ + off_m[:, None] * D
                          + tl.arange(0, BLOCK_DMODEL)[None, :],
                      dq_partial)

    tl.store(dK + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :], dk)
    tl.store(dV + off_n[:, None] * D + tl.arange(0, BLOCK_DMODEL)[None, :], dv)

实战经验

  • 先写 forward + gradcheck,再写 backward:能 forward 跑对就成功一半
  • backward 单独测torch.autograd.gradcheck(my_func, (q.double(), k.double(), v.double()), eps=1e-6),关闭 causal、用 FP64
  • 大量 atomic 时考虑分 split-k:先把 dQ 写到 [NUM_SPLITS, M, D] 中间 buffer,再用单独 reduce kernel 合并 —— 比纯 atomic 更快

13.6.5 自动微分的未来(实验性)

Triton GH Discussion #6913 有 Iaroslav Elistratov 的 PoC:在 Triton IR 层做自动微分。

  • 对 FA2 forward 写 backward 可移除 300 行用户代码
  • 对 LayerNorm 移除 120 行
  • 还未合入主线,但路线明确

预计 Triton 3.5+ 会有 @triton.jit(autodiff=True) 装饰器选项,类似 JAX 的 jax.grad,自动生成 backward kernel。

相关:AlexDremov 的 Triton + autograd 实战


13.7 本章小结

把六个真实案例串成一条线,你会发现工业级 Triton 的"道法术":

(设计哲学):

  • 写一份 Triton 顶替多份 CUDA,跨架构维护成本骤降(vLLM, SGLang)
  • 物理 / 逻辑分离 + 间接寻址,把 OS 思想搬进 GPU 内存管理(PagedAttention)
  • 数学等价的算子重排能省 3-5× 显存(Unsloth LoRA 关联律)

(架构模式):

  • tl.constexpr 分支统一 source、分裂 binary(vLLM phase 3)
  • sorted + padded → 单 kernel grouped GEMM(SGLang fused_moe)
  • 两级 scaling 隔离 dequant overhead(MOSS / COAT FP8)
  • recompute 中间张量 + 只存 reduce 后统计量(FA backward)

(工程实践):

  • TORCH_COMPILE_DEBUG=1 + output_code.py 读懂 Inductor 在做什么
  • torch.library.triton_op 优先于 torch.autograd.Function
  • atomic_add 高 reduction 时先 SMEM 局部归约
  • LoRA 微调强制走 (A @ x) 优先的右结合
  • backward kernel 按 K/V 切(dK, dV 无需 atomic),dQ 用 atomic

关键对照表

系统核心创新Triton 角色可学到
vLLMPagedAttentionAMD 主力 / NVIDIA 备份间接寻址 + grid 形态切换
SGLangRadixAttention量化栈 / MoE / AMDsorted+padded grouped GEMM
PyTorch Inductorfusion + autotunepointwise/reduction 后端fusion 规则与调试技巧
UnslothLoRA 算子级融合全栈算子重排 + 自定义 autograd
MOSS / COAT两级 scaling FP8量化 kerneldequant overhead 控制
FlashAttentionrecompute backwardforward + backward 全套反向核函数设计

思考题

思考题 1:vLLM Phase 2 的 throughput 提升来源

vLLM 从单 kernel 改成 prefill/decode 双 kernel 后 throughput 提升 3.7×。请定量分析至少三个性能瓶颈的解除:grid 形态、寄存器占用、SMEM 利用率。如果给你一个 batch 包含 50% prefill (avg seq=512) + 50% decode (seq=1),单 kernel 在 H100 上 SM 占用率大约是多少?

参考思路
  • 单 kernel grid 按最大 query_len 切,decode 的 query_len=1 的 program 立刻 mask 退出,算力浪费 ~99%
  • 单 kernel BLOCK_M 必须足够大覆盖 prefill,导致 decode 时寄存器/SMEM 都被无效分配
  • 双 kernel 后 decode 用紧致的 2D grid (num_seqs, num_heads),每个 program 工作量饱满
  • 估算:单 kernel decode 部分 SM 占用率仅 1-5%;双 kernel 后 60-80%

思考题 2:SGLang fused_moe 的 padding 浪费

fused_moe_kernel 通过 padding 让每个 expert 的 token 数对齐 BLOCK_M。若 BLOCK_M=128、有 8 个 expert,最坏情况下 padding 浪费多少 token?给定 total_tokens = 4096, num_experts = 8 且路由完全均匀,浪费率是多少?路由极不均匀(90% token 路由到 1 个 expert)时呢?

思考题 3:Inductor 融合断点定位

写一段代码 def f(x): return torch.softmax(x @ W, dim=-1) + bias,用 TORCH_COMPILE_DEBUG=1ir_pre_fusion.txtir_post_fusion.txt。请回答:

  1. softmax 是被融合成 triton_per_* 还是 triton_red_*?决定因素是什么?
  2. @W 走的是 Triton 还是 CUTLASS?为什么 Inductor 默认在大 GEMM 上不用 Triton?
  3. 最终 + bias 能否融入 softmax kernel?还是单独成 kernel?

思考题 4:Unsloth LoRA 关联律的反例

(B @ A) @ x vs B @ (A @ x)out=11008, in=4096, r=16, batch=8192 上对比显存与时间。何时前者反而更快?(提示:思考极小 batch、极小 in/out 维度的情形)

思考题 5:FP8 两级 scaling 的精度边界

设 fine-grain group 是 128 个元素的 E8M0 scale(仅指数),coarse scale 是 per-tensor FP32。如果 weight tensor 的全局 max 是 100,但某个 128-token group 的局部 max 是 0.01,dequant 时实际表示精度损失多少 bit?为什么 DeepSeek-V3 用 [128, 128] 而不是 [1024, 1024] 的 block?

思考题 6:Backward kernel atomic 还是 split-k

FA backward 的 dQ 写回有两种实现:

  • A:tl.atomic_add 直接写最终位置
  • B:先写到 [NUM_SPLITS, M, D] 中间 buffer,再单独 reduce kernel 合并

BLOCK_M=64, BLOCK_N=128, NUM_HEADS=32, N_CTX=8192, NUM_SPLITS=N_CTX/BLOCK_N=64。请估算两种方案的:

  1. 中间 HBM 占用
  2. 跨 SM atomic 冲突频率
  3. 总 HBM 带宽消耗

什么情况下 A 更优?什么情况下 B?

参考思路
  • B 方案中间 buffer = 64 × 8192 × 128 × 4B = 256 MB,仅 dQ 已很重;A 方案零额外
  • A 的 atomic 冲突频率与 NUM_HEADS × NUM_K_BLOCKS 成正比;高时性能严重下降
  • NUM_K_BLOCKS 小(短序列)→ A 优;长序列 / 高 head 数 → B 优
  • FlashAttention v2 论文选 A 是因为长序列下 SMEM 容量限制 B 的中间 buffer

下一章我们将转向工具链 —— 当你的 kernel 跑得不够快时,如何用 NSight Compute、Triton 内置 profiler、PyTorch profiler 三件套定位真正的瓶颈,并把"看上去对"的微优化与"实测有效"的优化分开。

基于 MIT 协议发布