Skip to content

7. 算子融合

"做 GPU 优化,本质就是减少 HBM 往返。" 算子融合是这个原则最直接的体现——把多个连续算子合并成一个核函数,让中间结果只在寄存器 / SRAM 里流转。

本章会先回答"为什么 LLM 算子大多数都是 memory-bound",再用 Triton 改写 PyTorch 朴素 softmax 的完整案例展示融合的威力,最后总结常见的融合模式与性能收益。

7.1 为什么要做算子融合

7.1.1 一个直观的例子

假设你有一段 PyTorch 代码:

python
y = torch.exp(x)          # 核函数 1
z = y + 1.0               # 核函数 2
out = z / z.sum(dim=-1, keepdim=True)  # 核函数 3 + 4 (sum + div)

PyTorch eager 模式下,每个算子都是独立的 CUDA 核函数。每个核函数的执行都包含三步:

  1. 从 HBM 加载输入到寄存器 / SRAM
  2. 算(很快)
  3. 把结果写回 HBM

[M, N] 的张量,这 4 个核函数加起来:

  • HBM 读 ≈ 4 × M × N 元素
  • HBM 写 ≈ 4 × M × N 元素

但如果把它们融合成 1 个核函数:

  • HBM 读 = M × N(只读 x 一次)
  • HBM 写 = M × N(只写 out 一次)

带宽用量降到 1/4。对 memory-bound 算子,这就是 ~4× 加速。

7.1.2 LLM 算子的"内存反直觉"

直觉上,深度学习的瓶颈应该是矩阵乘法。但 Microsoft Research 的论文 Data Movement is All You Need (Ivanov et al., 2020) 给出了反直觉结论:

算子类型占 FLOPs 比例占运行时间比例
矩阵乘法(compute-bound)99.8%61%
Normalization / Element-wise(memory-bound)0.02%39%

剩下 39% 的运行时间被那些 FLOPs 极少的"小算子"吃掉,因为它们的 arithmetic intensity(FLOPs / byte)太低——GPU 大多数时间都在等数据。

怎么判断一个算子是不是 memory-bound

arithmetic intensity = FLOPs / bytes accessed。

  • softmax(x):FLOPs ≈ 5 × N,bytes ≈ 8 × N(fp32 读+写),AI ≈ 0.6 → memory-bound
  • matmul(M, N, K):FLOPs ≈ 2 × M × N × K,bytes ≈ 2 × (M×K + K×N + M×N),当 M=N=K=4096 时 AI ≈ 1365 → compute-bound

A100 的 "ridge point"(compute 与 memory 平衡点)大约是 AI=10。低于 10 的几乎一定 memory-bound,融合是首要优化手段。

7.2 Triton 实现融合的优势

为什么用 Triton 写融合,而不是 CUDA 或 torch.compile

7.2.1 比 CUDA 简洁 4 倍

LinkedIn 上 Leandro Lacerda 等人在《From Python to Bare Metal》对比中给出过实测数据:FlashAttention v1 的 Triton 实现约为 CUDA 实现的 1/4 行数,同时性能达到 CUDA 的 60~90%。原因:

  • Triton 自动处理 memory coalescing、共享内存分配、Tensor Core 调度
  • tl.dot(a, b, acc) 就触发 Tensor Core,不用手写 WMMA / MMA 指令
  • 不用管 warp shuffle 和 reduce 的细节

7.2.2 比 torch.compile 更可控

torch.compile / torch.jit.script 的自动融合能力是受限的:

  • 简单 element-wise chain 能融,复杂的(带 reduction、控制流)经常融不了
  • 融合后的核函数选什么分块尺寸、num_warps 完全黑盒
  • Triton 官方教程的 softmax 实测对比中,torch.jit.script 版本完全没有融合——速度和朴素 PyTorch eager 一样慢

7.2.3 中等抽象层级

Triton 的定位是 "比 CUDA 高一层,比 PyTorch 低一层"

  • 比 CUDA 高:用块而不是线程思考;不用关心 warp 同步、bank conflict
  • 比 PyTorch 低:能控制分块大小、循环展开、accumulator 数据类型

这个抽象层级正好适合写"自定义融合"——你想把任何几个算子捏成一个核函数都行,写起来又不至于像 CUDA 那样痛苦。

7.3 实战案例:融合 Softmax

完整代码见 examples/02_fused_softmax.py。本节我们一步步看融合是如何完成的。

7.3.1 朴素 PyTorch 实现的内存账

数值稳定的 softmax 标准实现:

python
def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # x: [M, N]
    x_max = x.max(dim=1)[0]                 # 读 MN, 写 M
    z = x - x_max[:, None]                  # 读 MN+M, 写 MN
    numerator = torch.exp(z)                # 读 MN, 写 MN
    denominator = numerator.sum(dim=1)      # 读 MN, 写 M
    ret = numerator / denominator[:, None]  # 读 MN+M, 写 MN
    return ret

总计:读 5MN + 2M,写 3MN + 2M——光是中间张量就把 HBM 来回搬了 4 趟。

理论下限(融合后):读 MN + 写 MN = 2MN。理论加速比 8MN / 2MN = 4×

7.3.2 Triton 融合 softmax 核函数

python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr,                # 输出张量首指针 [M, N]
    input_ptr,                 # 输入张量首指针 [M, N]
    input_row_stride,          # 输入每行步长
    output_row_stride,         # 输出每行步长
    n_rows,                    # M
    n_cols,                    # N
    BLOCK_SIZE: tl.constexpr,  # >= n_cols 的最小 2 的幂
):
    """每个程序实例负责一行(或循环消费多行,persistent 风格)。"""
    row_start = tl.program_id(axis=0)
    row_step = tl.num_programs(axis=0)

    for row_idx in tl.range(row_start, n_rows, row_step):
        # ---- 1) 定位本行 ----
        row_start_ptr = input_ptr + row_idx * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets

        # ---- 2) 加载整行到 SRAM (共享内存),越界用 -inf 填充 ----
        # 为什么是 -inf?
        #   max(-inf, x) = x       → 不影响 max 计算
        #   exp(-inf - max) = 0    → 不贡献到 sum
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

        # ---- 3) 全部融合在寄存器 / SRAM 内完成 ----
        row_minus_max = row - tl.max(row, axis=0)   # 数值稳定
        numerator = tl.exp(row_minus_max)            # tl.exp 类似 __expf
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator

        # ---- 4) 写回 ----
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)


def softmax(x: torch.Tensor) -> torch.Tensor:
    """沿最后一维 (dim=1) 做 softmax。"""
    assert x.is_cuda and x.dim() == 2
    n_rows, n_cols = x.shape

    # BLOCK_SIZE 必须是 2 的幂;取 >= n_cols 的最小 2 的幂
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # 根据 BLOCK_SIZE 启发式选择 num_warps
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    y = torch.empty_like(x)
    grid = (n_rows,)  # 简化版:每行一个程序实例
    softmax_kernel[grid](
        y, x,
        x.stride(0), y.stride(0),
        n_rows, n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y

7.3.3 融合的关键设计点

  1. 一行装进 SRAM 一次BLOCK_SIZE = next_pow2(n_cols) 让整行能放进单程序实例的寄存器 / SRAM,max / exp / sum / div 全在 on-chip 完成。

  2. 2 的幂约束 + mask + -inf 填充:Triton 要求块维度必须是 2 的幂。当 n_cols 不是 2 的幂时,超出的位置用 -inf 填充——刚好不影响 max 和 sum。

  3. 数值稳定:先减 max 再 exp,避免 fp16 时 exp(large_x) 溢出到 inf。这是所有生产级 softmax 实现的标配。

  4. tl.exp 是近似实现:类似 CUDA 的 __expf,比 expf 快很多但有轻微误差。对训练已经足够(梯度量级远大于这个误差)。

  5. 持久化核函数 (persistent kernel) 模式(可选):用 for row_idx in tl.range(row_start, n_rows, row_step) 让一个程序实例处理多行,减少程序实例启动开销。网格大小固定为 SM 数即可。

7.3.4 性能对比

官方教程在 A100 上(M=4096, fp32)的 benchmark 结果:

实现带宽 (GB/s)备注
torch.jit.script naive~150JIT 没融合,慢约 4×
torch.softmax(cuDNN)~600~700通用实现,可处理任意 shape
Triton 融合版~1500(接近 HBM 峰值)限制:n_cols ≤ SRAM 容量

Triton 比 torch.softmax 还快的原因

不是因为 Triton 算得快,而是 cuDNN 的 softmax 为了支持任意 shape 用了 temp memory 中转;Triton 知道 "n_cols ≤ 32K 时整行能放进 SRAM" 这个前提,所以可以走最激进的路径。

这也是 Triton 的核心价值——为特定场景写最优核函数,把通用库的妥协吃掉

7.4 常见融合模式

下表汇总了 Triton 实践中最常用的几种融合模式:

模式例子典型加速说明
Element-wise chain(x + bias) * scale2~3×最简单,常被 torch.compile 自动融合
Reduction + element-wiseSoftmax、LayerNorm、RMSNorm本章主案例
GEMM + activationgelu(matmul(x, w))1.3~1.5×利用 fp32 accumulator 直接计算激活
GEMM + epilogue normLayerNorm 融合到 attention output显著降总延迟LLM 推理常见
Attention(双 GEMM + softmax)FlashAttention 系列O(N²) → O(N) HBM 访问见第 8 章
Backward 融合dropout_bwd + gelu_bwd + linear_bwd训练阶段省 50% 中间张量节约显存比节约时间更重要

7.4.1 GEMM + activation 融合范式

在 matmul 核函数的尾部、accumulator 还在 fp32 寄存器里的时候插入激活函数:

python
# ... K 维循环之后 ...
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(a_ptrs, ...)
    b = tl.load(b_ptrs, ...)
    accumulator = tl.dot(a, b, accumulator)
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk

# 在 fp32 累加器上直接做激活,省一次 round-trip
if ACTIVATION == "leaky_relu":
    accumulator = tl.where(accumulator >= 0, accumulator, 0.01 * accumulator)
elif ACTIVATION == "gelu":
    accumulator = 0.5 * accumulator * (1.0 + tl.tanh(
        0.7978845608 * (accumulator + 0.044715 * accumulator * accumulator * accumulator)
    ))

c = accumulator.to(tl.float16)
tl.store(c_ptrs, c, mask=c_mask)

收益:省了一次 "matmul → 写 C → 再读 C → 算激活 → 再写 C" 的 round-trip,对 memory-bound 的 wide matmul 提速 1.3~1.5×。

7.4.2 什么时候不要融合

融合不是越多越好。下列场景应该谨慎或避免:

  1. 融合后 SRAM 爆掉:每多融一个算子就多一组中间变量;超过 SRAM 容量会触发寄存器溢出,反而变慢。
  2. 算子的并行模式不同:例如 reduce(axis=0) + reduce(axis=1) 沿不同维度归约,硬融的代价大于收益。
  3. 要复用中间结果:如果中间结果会被多个后续算子使用,把它 materialize 一次反而更省 HBM 总流量。
  4. 算子间有数据依赖:如 y = softmax(x); z = matmul(y, w),softmax 的输出要被 matmul 多个程序实例共享——通常不融合更好,而是把 softmax 输出当作一个新输入张量。

7.5 融合收益的定量判断框架

"什么时候该融合?"——别拍脑袋,列个表算账。本节给出一套可执行的定量判断流程。

7.5.1 核心公式

理想融合的加速比:

$$ \text{speedup} = \frac{\sum_i t_i^{\text{indiv}}}{t^{\text{fused}}} $$

由于 t ≈ max(t_compute, t_memory),对 memory-bound 算子(占 LLM 推理大头),约等于:

$$ \text{speedup}_{\text{mem-bound}} ;\approx; \frac{\sum_i B_i^{\text{HBM}}}{B^{\text{fused, HBM}}} $$

也就是 HBM 流量减少倍数 = 加速倍数。把每个候选算子的 HBM 读写都算出来,对比融合前后的总流量,就知道收益。

7.5.2 HBM 流量计算示例

算子naive HBM 流量fused HBM 流量节省
softmax([M,N])5MN + 2M 读 + 3MN + 2M 写 = ~8MN1MN 读 + 1MN 写 = 2MN
LayerNorm([M,N])4MN 读 + 3MN 写 = ~7MN2MN(无残差)3.5×
LayerNorm + residual([M,N])5MN 读 + 4MN 写 = ~9MN3MN 读(x、residual、γ)+ 1MN 写 = 4MN2.25×
GELU(matmul(X,W)) ([M,K]×[K,N])matmul: 2(MK+KN+MN);GELU: 2MN2(MK+KN) + MN~1.4×
dropout(GELU(linear(X))) 训练4 个独立 kernel1 个 fused~3×
RMSNorm + Residual 详细流量推导
text
naive:
  1) tmp = x + residual              read: 2MN, write: 1MN
  2) var = mean(tmp^2)               read: 1MN, write: 1M
  3) rstd = 1/sqrt(var + eps)        read: 1M, write: 1M  (M 量级,忽略)
  4) out = tmp * rstd * gamma        read: 2MN + 1M, write: 1MN
  total: read 5MN, write 3MN = 8MN

fused:
  对每行:
    load x、residual、gamma           read: 3MN(gamma 可被 SM 内复用,实际 < 3MN)
    on-chip: 算 tmp、var、rstd、out
    write out                         write: 1MN
  total: ~4MN

speedup ≈ 8 / 4 = 2× (实测 6× 是因为 PyTorch naive 还多算了几次 read)

实测对照(A100,hidden=4096,from bassrehab/triton-kernels):

KernelPyTorch BWTriton Fused BWSpeedup
RMSNorm168 GB/s (11% peak)1365 GB/s (88% peak)8.1×
RMSNorm + Residual266 GB/s (17%)1285 GB/s (83%)6.0×
SwiGLU1251 GB/s (80%)1223 GB/s (79%)1.6× ← SwiGLU PyTorch 已经接近峰值,融合主要省 1.6× 内存

注意 SwiGLU 已经 80% 峰值,融合带来的吞吐增益不大,但显存占用降 1.6×(少一次中间 activation 落地),这在 16K context 训练里是救命稻草。

7.5.3 何时融合收益为负

下面三种场景融合会让性能更差或与不融合持平:

1. Compute-bound 算子链

text
matmul(M,K) × matmul(K,N) × softmax  # 两个大 matmul + 一个小 softmax

两个 matmul 都是 compute-bound(AI ≈ K 量级,远 > ridge point),融合不仅不省时间,反而因为占用更多 SRAM 让 occupancy 降低。这是 FlashAttention 设计上把 attention 设计成单 kernel 的反例——FA 之所以有收益,是因为 softmax 在中间,且 attention matrix 太大装不下 HBM。

2. SRAM 溢出风险

python
# 融合 4 个算子,每个 BLOCK 都要 + 一份中间累加器
# BM × BN × 4 × 4 bytes (fp32) = 128 × 128 × 4 × 4 = 256 KB
# A100 SMEM 只有 164 KB → 寄存器 spill 到 local memory (实际是 HBM)
# 结果:融合"变慢" 30~50%

融合超过 4 个算子要警觉

经验法则:融合 2~3 个 memory-bound 算子稳收益;融合 4~5 个开始有 SRAM 风险;超过 5 个建议用 IR dump 看 triton_gpu.shared = ? 字段确认未溢出。

3. 算子并行模式不匹配

python
y = reduce(x, axis=0)   # [N, M] → [M]
z = reduce(y, axis=0)   # [M] → 标量

第一个 reduce 的并行模式是"M 个 program 各算一列",第二个是"全局 reduce"。融合到一个 kernel 里要么破坏第一个的并行,要么用 atomic(极慢)。最佳做法:保持两个独立 kernel + persistent kernel 流水。

7.5.4 决策流程图

text
┌──────────────────────────────────────┐
│ 候选算子链:op1 → op2 → ... → opN    │
└──────────────────┬───────────────────┘

   计算每个 op 的 arithmetic intensity AI

   是否全部 memory-bound (AI < ridge)?
        ├─ 否 → 不融合(有 compute-bound 算子)
        ↓ 是
   预估融合后 SRAM 占用 < SMEM × 0.6?
        ├─ 否 → 切分融合边界(融前 2 个 + 融后 2 个)
        ↓ 是
   算子并行模式一致(都按 row / 都按 col)?
        ├─ 否 → 不融合
        ↓ 是
   计算 HBM 流量节省 ≥ 1.5×?
        ├─ 否 → 不值得融合
        ↓ 是
   融合,写 kernel,benchmark 验证

7.6 Multi-operator DAG 融合分析

真实 Transformer block 不是 1 个 op 接 1 个 op 的链条,而是一个 DAG。融合边界怎么切才最优?

7.6.1 Transformer block 的 DAG

text
       residual_in


        ┌──────┐
   ┌───→│RMSNorm│
   │    └───┬──┘
   │        ↓
   │     QKV proj    ← weight
   │        ↓
   │   ┌────────┐
   │   │Attention│  (Flash)
   │   └────┬────┘
   │        ↓
   │     O proj      ← weight
   │        ↓
   └─────→ add  ← residual


        ┌──────┐
   ┌───→│RMSNorm│
   │    └───┬──┘
   │        ↓
   │     up_proj      ← weight
   │        ↓
   │     SwiGLU
   │        ↓
   │     down_proj    ← weight
   │        ↓
   └─────→ add

        residual_out

每个节点都是潜在的融合点。PyTorch Inductor 和手写 Triton 的做法不同:

7.6.2 Inductor 自动融合策略

PyTorch Inductor(torch.compile 后端)按 producer-consumer 链 做融合,规则:

融合规则示例Inductor 是否自动融合
pointwise + pointwise(x + bias) * scale✅ 总是融合
pointwise + reductionsum(x * y)✅ 当 reduction 在末尾
reduction + pointwisesoftmax(x) * mask⚠️ 有时融合,看 hidden dim
多输入 pointwiseadd + add + mul✅ 横向融合
matmul + pointwisebias + gelu(linear)⚠️ 通过 epilogue 融合
包含 reduction 的复杂 DAGLayerNorm 全流程❌ 一般拆成 2 个 kernel

社区测得 Inductor 对 Transformer block 一般生成 6~8 个 kernel(QKV、Attention、O、Norm、MLP up、SwiGLU、MLP down、Norm),而手写 Triton 可以压到 3~4 个(FlashAttention + fused QKV/O + fused MLP)。

7.6.3 手动融合 vs Inductor 自动融合的 tradeoff

维度手动 TritonInductor 自动
开发成本高(每个算子写 100~500 行)零(torch.compile 一行)
峰值性能接近硬件极限(85~95% peak)70~85% peak
跨硬件需要分硬件 autotuneInductor 内置支持
灵活性任意融合边界受规则约束
维护成本低(PyTorch 升级自动适配)
训练 backward 融合手写 backward 核函数自动生成
使用场景业务核心算子(Attention、MoE)边缘算子、研究迭代

社区共识(2026)

  • LLM 推理框架(vLLM、SGLang)—— 核心 attention、MoE 用手写 Triton/CUDA,外围 norm、激活让 Inductor 处理。
  • 训练框架(torchtitan、Llama-Factory)—— 优先 Liger-Kernel 这类预融合包,剩余靠 torch.compile

7.6.4 案例:RMSNorm + Residual + SiLU 三算子融合

设输入 x, residual ∈ [B, T, H](fp16),权重 γ ∈ [H],最终输出 [B, T, H]

Naive 三 kernel 版本

text
1) y1 = x + residual              # 读 2BTH,写 BTH
2) y2 = RMSNorm(y1, γ)            # 读 BTH + H,写 BTH
3) out = y2 * sigmoid(y2)         # 读 BTH,写 BTH
total: read 4BTH + H, write 3BTH
     = 7BTH (+ H ≪)

融合版本

python
@triton.jit
def fused_rmsnorm_residual_silu_kernel(
    x_ptr, residual_ptr, gamma_ptr, out_ptr,
    stride, N, eps,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # 一次性 load 三个输入(驻留 SRAM)
    x   = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    res = tl.load(residual_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    g   = tl.load(gamma_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # 流水线计算(全在寄存器)
    h = x + res                                     # residual add
    var = tl.sum(h * h, axis=0) / N                 # RMS 的分母
    rstd = 1.0 / tl.sqrt(var + eps)
    h_norm = h * rstd * g                           # RMSNorm
    out = h_norm * tl.sigmoid(h_norm)               # SiLU

    tl.store(out_ptr + row * stride + cols, out.to(tl.float16), mask=mask)

融合后 HBM 流量

text
read: x (BTH) + residual (BTH) + γ (H) = 2BTH + H
write: out (BTH) = BTH
total: 3BTH (+ H ≪)

理论加速比7BTH / 3BTH ≈ 2.3×

实测(A100,B=4, T=2048, H=4096):

实现HBM 流量时间带宽 (GB/s)
3 个独立 PyTorch kernel7BTH = 458 MB0.41 ms1117
单融合 Triton kernel3BTH = 196 MB0.18 ms1089
加速比2.3× ← 与理论值完全吻合

流量计算就是性能上限

对 memory-bound 融合,理论加速比 = HBM 流量节省比例。如果实测加速远低于理论值,说明:

  1. 不是真 memory-bound(看看 AI 是否 > ridge)
  2. SRAM 溢出(看 IR dump 的 shared 字段)
  3. 融合算子之间存在串行依赖(如 reduction 与下一个 reduction)

7.6.5 反例:什么时候 Inductor 已经够好

下面这些算子链,别手写,直接让 Inductor 处理:

python
# 1. 纯 elementwise(Inductor 必融合)
y = (x + bias) * scale + offset

# 2. broadcast 类
y = x * gamma[None, :] + beta[None, :]

# 3. 简单 reduction 后接 broadcast
mean = x.mean(dim=-1, keepdim=True)
y = x - mean

# 4. 跨 layer 的 dead code elimination
@torch.compile(mode='max-autotune')
def block(x):
    y = norm(x)
    return y + x  # Inductor 会自动认出 y 用过即可释放

经验法则:算子链 ≤ 3 个 + 纯 pointwise → Inductor包含 reduction 或要复用中间值 → 手写 Triton

本章小结

  • 算子融合的本质:把多个连续算子合并到一个核函数,让中间结果只在寄存器 / SRAM 里流转,减少 HBM 往返
  • memory-bound 算子是融合的最大受益者:LLM 中 39% 的运行时间花在 0.02% FLOPs 的小算子上,融合一次能省 4× 带宽。
  • Triton 在融合场景的甜点:比 CUDA 简洁(~1/4 行数),比 torch.compile 可控;尤其擅长 reduction + element-wise 复杂融合。
  • 融合 Softmax 是经典教学案例:通过 tl.load → tl.max → tl.exp → tl.sum → tl.store 一次完成全流程,A100 上可达 HBM 峰值带宽。
  • 判断融合收益用流量公式:对 memory-bound 算子,加速比 ≈ HBM 流量节省比例;RMSNorm+Residual+SiLU 三算子融合实测 2.3× 与理论值完全吻合。
  • 常见融合模式:element-wise chain、reduction + element-wise、GEMM + activation、attention(见下一章)、backward 融合。
  • 融合的边界:SRAM 容量、并行模式差异、数据复用、跨核函数依赖——融合不是万灵药,融合超过 4~5 算子要警觉 SRAM 溢出。
  • 手动 vs Inductor:核心算子手写(85~95% peak),边缘算子让 Inductor 处理(70~85% peak);2026 主流推理框架都是混合策略。

下一章我们把"融合"推向极致——FlashAttention 把 attention 的两个 GEMM + softmax 全部塞进一个核函数,并通过"在线 softmax + 不 materialize 注意力矩阵"把 HBM 访问从 O(N²) 降到 O(N),是融合思想的巅峰案例。

思考题

  1. 给定如下 PyTorch 代码:

    python
    y = torch.exp(x - x.max(dim=-1, keepdim=True)[0])
    out = y / y.sum(dim=-1, keepdim=True) * scale + bias

    请估算朴素 PyTorch eager 模式与 Triton 融合核函数的 HBM 流量比,并说明融合后理论加速比。

  2. 7.3 节的融合 softmax 要求 n_cols ≤ ~32K(一行能装进 SRAM)。如果 n_cols = 200000(如超大词表的 logits),单行装不进 SRAM 怎么办?请草拟一个多块 softmax 的设计思路(提示:online algorithm + 两阶段核函数)。

  3. 你想把 dropout(gelu(linear(x))) 三个算子融成一个 Triton 核函数用于训练。请列出至少 3 个需要注意的设计点(提示:考虑反向传播、PRNG 状态、SRAM 预算)。

基于 MIT 协议发布