7. 算子融合
"做 GPU 优化,本质就是减少 HBM 往返。" 算子融合是这个原则最直接的体现——把多个连续算子合并成一个核函数,让中间结果只在寄存器 / SRAM 里流转。
本章会先回答"为什么 LLM 算子大多数都是 memory-bound",再用 Triton 改写 PyTorch 朴素 softmax 的完整案例展示融合的威力,最后总结常见的融合模式与性能收益。
7.1 为什么要做算子融合
7.1.1 一个直观的例子
假设你有一段 PyTorch 代码:
y = torch.exp(x) # 核函数 1
z = y + 1.0 # 核函数 2
out = z / z.sum(dim=-1, keepdim=True) # 核函数 3 + 4 (sum + div)PyTorch eager 模式下,每个算子都是独立的 CUDA 核函数。每个核函数的执行都包含三步:
- 从 HBM 加载输入到寄存器 / SRAM
- 算(很快)
- 把结果写回 HBM
对 [M, N] 的张量,这 4 个核函数加起来:
- HBM 读 ≈
4 × M × N元素 - HBM 写 ≈
4 × M × N元素
但如果把它们融合成 1 个核函数:
- HBM 读 =
M × N(只读 x 一次) - HBM 写 =
M × N(只写 out 一次)
带宽用量降到 1/4。对 memory-bound 算子,这就是 ~4× 加速。
7.1.2 LLM 算子的"内存反直觉"
直觉上,深度学习的瓶颈应该是矩阵乘法。但 Microsoft Research 的论文 Data Movement is All You Need (Ivanov et al., 2020) 给出了反直觉结论:
| 算子类型 | 占 FLOPs 比例 | 占运行时间比例 |
|---|---|---|
| 矩阵乘法(compute-bound) | 99.8% | 61% |
| Normalization / Element-wise(memory-bound) | 0.02% | 39% |
剩下 39% 的运行时间被那些 FLOPs 极少的"小算子"吃掉,因为它们的 arithmetic intensity(FLOPs / byte)太低——GPU 大多数时间都在等数据。
怎么判断一个算子是不是 memory-bound
看 arithmetic intensity = FLOPs / bytes accessed。
- 对
softmax(x):FLOPs ≈ 5 × N,bytes ≈ 8 × N(fp32 读+写),AI ≈ 0.6 → memory-bound - 对
matmul(M, N, K):FLOPs ≈ 2 × M × N × K,bytes ≈ 2 × (M×K + K×N + M×N),当 M=N=K=4096 时 AI ≈ 1365 → compute-bound
A100 的 "ridge point"(compute 与 memory 平衡点)大约是 AI=10。低于 10 的几乎一定 memory-bound,融合是首要优化手段。
7.2 Triton 实现融合的优势
为什么用 Triton 写融合,而不是 CUDA 或 torch.compile?
7.2.1 比 CUDA 简洁 4 倍
LinkedIn 上 Leandro Lacerda 等人在《From Python to Bare Metal》对比中给出过实测数据:FlashAttention v1 的 Triton 实现约为 CUDA 实现的 1/4 行数,同时性能达到 CUDA 的 60~90%。原因:
- Triton 自动处理 memory coalescing、共享内存分配、Tensor Core 调度
- 写
tl.dot(a, b, acc)就触发 Tensor Core,不用手写 WMMA / MMA 指令 - 不用管 warp shuffle 和 reduce 的细节
7.2.2 比 torch.compile 更可控
torch.compile / torch.jit.script 的自动融合能力是受限的:
- 简单 element-wise chain 能融,复杂的(带 reduction、控制流)经常融不了
- 融合后的核函数选什么分块尺寸、num_warps 完全黑盒
- Triton 官方教程的 softmax 实测对比中,
torch.jit.script版本完全没有融合——速度和朴素 PyTorch eager 一样慢
7.2.3 中等抽象层级
Triton 的定位是 "比 CUDA 高一层,比 PyTorch 低一层":
- 比 CUDA 高:用块而不是线程思考;不用关心 warp 同步、bank conflict
- 比 PyTorch 低:能控制分块大小、循环展开、accumulator 数据类型
这个抽象层级正好适合写"自定义融合"——你想把任何几个算子捏成一个核函数都行,写起来又不至于像 CUDA 那样痛苦。
7.3 实战案例:融合 Softmax
完整代码见 examples/02_fused_softmax.py。本节我们一步步看融合是如何完成的。
7.3.1 朴素 PyTorch 实现的内存账
数值稳定的 softmax 标准实现:
def naive_softmax(x: torch.Tensor) -> torch.Tensor:
# x: [M, N]
x_max = x.max(dim=1)[0] # 读 MN, 写 M
z = x - x_max[:, None] # 读 MN+M, 写 MN
numerator = torch.exp(z) # 读 MN, 写 MN
denominator = numerator.sum(dim=1) # 读 MN, 写 M
ret = numerator / denominator[:, None] # 读 MN+M, 写 MN
return ret总计:读 5MN + 2M,写 3MN + 2M——光是中间张量就把 HBM 来回搬了 4 趟。
理论下限(融合后):读 MN + 写 MN = 2MN。理论加速比 8MN / 2MN = 4×。
7.3.2 Triton 融合 softmax 核函数
import torch
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr, # 输出张量首指针 [M, N]
input_ptr, # 输入张量首指针 [M, N]
input_row_stride, # 输入每行步长
output_row_stride, # 输出每行步长
n_rows, # M
n_cols, # N
BLOCK_SIZE: tl.constexpr, # >= n_cols 的最小 2 的幂
):
"""每个程序实例负责一行(或循环消费多行,persistent 风格)。"""
row_start = tl.program_id(axis=0)
row_step = tl.num_programs(axis=0)
for row_idx in tl.range(row_start, n_rows, row_step):
# ---- 1) 定位本行 ----
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# ---- 2) 加载整行到 SRAM (共享内存),越界用 -inf 填充 ----
# 为什么是 -inf?
# max(-inf, x) = x → 不影响 max 计算
# exp(-inf - max) = 0 → 不贡献到 sum
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# ---- 3) 全部融合在寄存器 / SRAM 内完成 ----
row_minus_max = row - tl.max(row, axis=0) # 数值稳定
numerator = tl.exp(row_minus_max) # tl.exp 类似 __expf
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# ---- 4) 写回 ----
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
def softmax(x: torch.Tensor) -> torch.Tensor:
"""沿最后一维 (dim=1) 做 softmax。"""
assert x.is_cuda and x.dim() == 2
n_rows, n_cols = x.shape
# BLOCK_SIZE 必须是 2 的幂;取 >= n_cols 的最小 2 的幂
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# 根据 BLOCK_SIZE 启发式选择 num_warps
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
y = torch.empty_like(x)
grid = (n_rows,) # 简化版:每行一个程序实例
softmax_kernel[grid](
y, x,
x.stride(0), y.stride(0),
n_rows, n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return y7.3.3 融合的关键设计点
一行装进 SRAM 一次:
BLOCK_SIZE = next_pow2(n_cols)让整行能放进单程序实例的寄存器 / SRAM,max / exp / sum / div 全在 on-chip 完成。2 的幂约束 + mask +
-inf填充:Triton 要求块维度必须是 2 的幂。当n_cols不是 2 的幂时,超出的位置用-inf填充——刚好不影响 max 和 sum。数值稳定:先减 max 再 exp,避免 fp16 时
exp(large_x)溢出到 inf。这是所有生产级 softmax 实现的标配。tl.exp是近似实现:类似 CUDA 的__expf,比expf快很多但有轻微误差。对训练已经足够(梯度量级远大于这个误差)。持久化核函数 (persistent kernel) 模式(可选):用
for row_idx in tl.range(row_start, n_rows, row_step)让一个程序实例处理多行,减少程序实例启动开销。网格大小固定为 SM 数即可。
7.3.4 性能对比
官方教程在 A100 上(M=4096, fp32)的 benchmark 结果:
| 实现 | 带宽 (GB/s) | 备注 |
|---|---|---|
torch.jit.script naive | ~150 | JIT 没融合,慢约 4× |
torch.softmax(cuDNN) | ~600~700 | 通用实现,可处理任意 shape |
| Triton 融合版 | ~1500(接近 HBM 峰值) | 限制:n_cols ≤ SRAM 容量 |
Triton 比 torch.softmax 还快的原因
不是因为 Triton 算得快,而是 cuDNN 的 softmax 为了支持任意 shape 用了 temp memory 中转;Triton 知道 "n_cols ≤ 32K 时整行能放进 SRAM" 这个前提,所以可以走最激进的路径。
这也是 Triton 的核心价值——为特定场景写最优核函数,把通用库的妥协吃掉。
7.4 常见融合模式
下表汇总了 Triton 实践中最常用的几种融合模式:
| 模式 | 例子 | 典型加速 | 说明 |
|---|---|---|---|
| Element-wise chain | (x + bias) * scale | 2~3× | 最简单,常被 torch.compile 自动融合 |
| Reduction + element-wise | Softmax、LayerNorm、RMSNorm | 4× | 本章主案例 |
| GEMM + activation | gelu(matmul(x, w)) | 1.3~1.5× | 利用 fp32 accumulator 直接计算激活 |
| GEMM + epilogue norm | LayerNorm 融合到 attention output | 显著降总延迟 | LLM 推理常见 |
| Attention(双 GEMM + softmax) | FlashAttention 系列 | O(N²) → O(N) HBM 访问 | 见第 8 章 |
| Backward 融合 | dropout_bwd + gelu_bwd + linear_bwd | 训练阶段省 50% 中间张量 | 节约显存比节约时间更重要 |
7.4.1 GEMM + activation 融合范式
在 matmul 核函数的尾部、accumulator 还在 fp32 寄存器里的时候插入激活函数:
# ... K 维循环之后 ...
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, ...)
b = tl.load(b_ptrs, ...)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# 在 fp32 累加器上直接做激活,省一次 round-trip
if ACTIVATION == "leaky_relu":
accumulator = tl.where(accumulator >= 0, accumulator, 0.01 * accumulator)
elif ACTIVATION == "gelu":
accumulator = 0.5 * accumulator * (1.0 + tl.tanh(
0.7978845608 * (accumulator + 0.044715 * accumulator * accumulator * accumulator)
))
c = accumulator.to(tl.float16)
tl.store(c_ptrs, c, mask=c_mask)收益:省了一次 "matmul → 写 C → 再读 C → 算激活 → 再写 C" 的 round-trip,对 memory-bound 的 wide matmul 提速 1.3~1.5×。
7.4.2 什么时候不要融合
融合不是越多越好。下列场景应该谨慎或避免:
- 融合后 SRAM 爆掉:每多融一个算子就多一组中间变量;超过 SRAM 容量会触发寄存器溢出,反而变慢。
- 算子的并行模式不同:例如
reduce(axis=0)+reduce(axis=1)沿不同维度归约,硬融的代价大于收益。 - 要复用中间结果:如果中间结果会被多个后续算子使用,把它 materialize 一次反而更省 HBM 总流量。
- 算子间有数据依赖:如
y = softmax(x); z = matmul(y, w),softmax 的输出要被 matmul 多个程序实例共享——通常不融合更好,而是把 softmax 输出当作一个新输入张量。
7.5 融合收益的定量判断框架
"什么时候该融合?"——别拍脑袋,列个表算账。本节给出一套可执行的定量判断流程。
7.5.1 核心公式
理想融合的加速比:
$$ \text{speedup} = \frac{\sum_i t_i^{\text{indiv}}}{t^{\text{fused}}} $$
由于 t ≈ max(t_compute, t_memory),对 memory-bound 算子(占 LLM 推理大头),约等于:
$$ \text{speedup}_{\text{mem-bound}} ;\approx; \frac{\sum_i B_i^{\text{HBM}}}{B^{\text{fused, HBM}}} $$
也就是 HBM 流量减少倍数 = 加速倍数。把每个候选算子的 HBM 读写都算出来,对比融合前后的总流量,就知道收益。
7.5.2 HBM 流量计算示例
| 算子 | naive HBM 流量 | fused HBM 流量 | 节省 |
|---|---|---|---|
softmax([M,N]) | 5MN + 2M 读 + 3MN + 2M 写 = ~8MN | 1MN 读 + 1MN 写 = 2MN | 4× |
LayerNorm([M,N]) | 4MN 读 + 3MN 写 = ~7MN | 2MN(无残差) | 3.5× |
LayerNorm + residual([M,N]) | 5MN 读 + 4MN 写 = ~9MN | 3MN 读(x、residual、γ)+ 1MN 写 = 4MN | 2.25× |
GELU(matmul(X,W)) ([M,K]×[K,N]) | matmul: 2(MK+KN+MN);GELU: 2MN | 2(MK+KN) + MN | ~1.4× |
dropout(GELU(linear(X))) 训练 | 4 个独立 kernel | 1 个 fused | ~3× |
RMSNorm + Residual 详细流量推导
naive:
1) tmp = x + residual read: 2MN, write: 1MN
2) var = mean(tmp^2) read: 1MN, write: 1M
3) rstd = 1/sqrt(var + eps) read: 1M, write: 1M (M 量级,忽略)
4) out = tmp * rstd * gamma read: 2MN + 1M, write: 1MN
total: read 5MN, write 3MN = 8MN
fused:
对每行:
load x、residual、gamma read: 3MN(gamma 可被 SM 内复用,实际 < 3MN)
on-chip: 算 tmp、var、rstd、out
write out write: 1MN
total: ~4MN
speedup ≈ 8 / 4 = 2× (实测 6× 是因为 PyTorch naive 还多算了几次 read)实测对照(A100,hidden=4096,from bassrehab/triton-kernels):
| Kernel | PyTorch BW | Triton Fused BW | Speedup |
|---|---|---|---|
| RMSNorm | 168 GB/s (11% peak) | 1365 GB/s (88% peak) | 8.1× |
| RMSNorm + Residual | 266 GB/s (17%) | 1285 GB/s (83%) | 6.0× |
| SwiGLU | 1251 GB/s (80%) | 1223 GB/s (79%) | 1.6× ← SwiGLU PyTorch 已经接近峰值,融合主要省 1.6× 内存 |
注意 SwiGLU 已经 80% 峰值,融合带来的吞吐增益不大,但显存占用降 1.6×(少一次中间 activation 落地),这在 16K context 训练里是救命稻草。
7.5.3 何时融合收益为负
下面三种场景融合会让性能更差或与不融合持平:
1. Compute-bound 算子链
matmul(M,K) × matmul(K,N) × softmax # 两个大 matmul + 一个小 softmax两个 matmul 都是 compute-bound(AI ≈ K 量级,远 > ridge point),融合不仅不省时间,反而因为占用更多 SRAM 让 occupancy 降低。这是 FlashAttention 设计上把 attention 设计成单 kernel 的反例——FA 之所以有收益,是因为 softmax 在中间,且 attention matrix 太大装不下 HBM。
2. SRAM 溢出风险
# 融合 4 个算子,每个 BLOCK 都要 + 一份中间累加器
# BM × BN × 4 × 4 bytes (fp32) = 128 × 128 × 4 × 4 = 256 KB
# A100 SMEM 只有 164 KB → 寄存器 spill 到 local memory (实际是 HBM)
# 结果:融合"变慢" 30~50%融合超过 4 个算子要警觉
经验法则:融合 2~3 个 memory-bound 算子稳收益;融合 4~5 个开始有 SRAM 风险;超过 5 个建议用 IR dump 看 triton_gpu.shared = ? 字段确认未溢出。
3. 算子并行模式不匹配
y = reduce(x, axis=0) # [N, M] → [M]
z = reduce(y, axis=0) # [M] → 标量第一个 reduce 的并行模式是"M 个 program 各算一列",第二个是"全局 reduce"。融合到一个 kernel 里要么破坏第一个的并行,要么用 atomic(极慢)。最佳做法:保持两个独立 kernel + persistent kernel 流水。
7.5.4 决策流程图
┌──────────────────────────────────────┐
│ 候选算子链:op1 → op2 → ... → opN │
└──────────────────┬───────────────────┘
↓
计算每个 op 的 arithmetic intensity AI
↓
是否全部 memory-bound (AI < ridge)?
├─ 否 → 不融合(有 compute-bound 算子)
↓ 是
预估融合后 SRAM 占用 < SMEM × 0.6?
├─ 否 → 切分融合边界(融前 2 个 + 融后 2 个)
↓ 是
算子并行模式一致(都按 row / 都按 col)?
├─ 否 → 不融合
↓ 是
计算 HBM 流量节省 ≥ 1.5×?
├─ 否 → 不值得融合
↓ 是
融合,写 kernel,benchmark 验证7.6 Multi-operator DAG 融合分析
真实 Transformer block 不是 1 个 op 接 1 个 op 的链条,而是一个 DAG。融合边界怎么切才最优?
7.6.1 Transformer block 的 DAG
residual_in
│
↓
┌──────┐
┌───→│RMSNorm│
│ └───┬──┘
│ ↓
│ QKV proj ← weight
│ ↓
│ ┌────────┐
│ │Attention│ (Flash)
│ └────┬────┘
│ ↓
│ O proj ← weight
│ ↓
└─────→ add ← residual
│
↓
┌──────┐
┌───→│RMSNorm│
│ └───┬──┘
│ ↓
│ up_proj ← weight
│ ↓
│ SwiGLU
│ ↓
│ down_proj ← weight
│ ↓
└─────→ add
│
residual_out每个节点都是潜在的融合点。PyTorch Inductor 和手写 Triton 的做法不同:
7.6.2 Inductor 自动融合策略
PyTorch Inductor(torch.compile 后端)按 producer-consumer 链 做融合,规则:
| 融合规则 | 示例 | Inductor 是否自动融合 |
|---|---|---|
| pointwise + pointwise | (x + bias) * scale | ✅ 总是融合 |
| pointwise + reduction | sum(x * y) | ✅ 当 reduction 在末尾 |
| reduction + pointwise | softmax(x) * mask | ⚠️ 有时融合,看 hidden dim |
| 多输入 pointwise | add + add + mul | ✅ 横向融合 |
| matmul + pointwise | bias + gelu(linear) | ⚠️ 通过 epilogue 融合 |
| 包含 reduction 的复杂 DAG | LayerNorm 全流程 | ❌ 一般拆成 2 个 kernel |
社区测得 Inductor 对 Transformer block 一般生成 6~8 个 kernel(QKV、Attention、O、Norm、MLP up、SwiGLU、MLP down、Norm),而手写 Triton 可以压到 3~4 个(FlashAttention + fused QKV/O + fused MLP)。
7.6.3 手动融合 vs Inductor 自动融合的 tradeoff
| 维度 | 手动 Triton | Inductor 自动 |
|---|---|---|
| 开发成本 | 高(每个算子写 100~500 行) | 零(torch.compile 一行) |
| 峰值性能 | 接近硬件极限(85~95% peak) | 70~85% peak |
| 跨硬件 | 需要分硬件 autotune | Inductor 内置支持 |
| 灵活性 | 任意融合边界 | 受规则约束 |
| 维护成本 | 高 | 低(PyTorch 升级自动适配) |
| 训练 backward 融合 | 手写 backward 核函数 | 自动生成 |
| 使用场景 | 业务核心算子(Attention、MoE) | 边缘算子、研究迭代 |
社区共识(2026):
- LLM 推理框架(vLLM、SGLang)—— 核心 attention、MoE 用手写 Triton/CUDA,外围 norm、激活让 Inductor 处理。
- 训练框架(torchtitan、Llama-Factory)—— 优先 Liger-Kernel 这类预融合包,剩余靠
torch.compile。
7.6.4 案例:RMSNorm + Residual + SiLU 三算子融合
设输入 x, residual ∈ [B, T, H](fp16),权重 γ ∈ [H],最终输出 [B, T, H]。
Naive 三 kernel 版本:
1) y1 = x + residual # 读 2BTH,写 BTH
2) y2 = RMSNorm(y1, γ) # 读 BTH + H,写 BTH
3) out = y2 * sigmoid(y2) # 读 BTH,写 BTH
total: read 4BTH + H, write 3BTH
= 7BTH (+ H ≪)融合版本:
@triton.jit
def fused_rmsnorm_residual_silu_kernel(
x_ptr, residual_ptr, gamma_ptr, out_ptr,
stride, N, eps,
BLOCK_SIZE: tl.constexpr,
):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
# 一次性 load 三个输入(驻留 SRAM)
x = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
res = tl.load(residual_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
g = tl.load(gamma_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# 流水线计算(全在寄存器)
h = x + res # residual add
var = tl.sum(h * h, axis=0) / N # RMS 的分母
rstd = 1.0 / tl.sqrt(var + eps)
h_norm = h * rstd * g # RMSNorm
out = h_norm * tl.sigmoid(h_norm) # SiLU
tl.store(out_ptr + row * stride + cols, out.to(tl.float16), mask=mask)融合后 HBM 流量:
read: x (BTH) + residual (BTH) + γ (H) = 2BTH + H
write: out (BTH) = BTH
total: 3BTH (+ H ≪)理论加速比:7BTH / 3BTH ≈ 2.3×。
实测(A100,B=4, T=2048, H=4096):
| 实现 | HBM 流量 | 时间 | 带宽 (GB/s) |
|---|---|---|---|
| 3 个独立 PyTorch kernel | 7BTH = 458 MB | 0.41 ms | 1117 |
| 单融合 Triton kernel | 3BTH = 196 MB | 0.18 ms | 1089 |
| 加速比 | — | — | 2.3× ← 与理论值完全吻合 |
流量计算就是性能上限
对 memory-bound 融合,理论加速比 = HBM 流量节省比例。如果实测加速远低于理论值,说明:
- 不是真 memory-bound(看看 AI 是否 > ridge)
- SRAM 溢出(看 IR dump 的 shared 字段)
- 融合算子之间存在串行依赖(如 reduction 与下一个 reduction)
7.6.5 反例:什么时候 Inductor 已经够好
下面这些算子链,别手写,直接让 Inductor 处理:
# 1. 纯 elementwise(Inductor 必融合)
y = (x + bias) * scale + offset
# 2. broadcast 类
y = x * gamma[None, :] + beta[None, :]
# 3. 简单 reduction 后接 broadcast
mean = x.mean(dim=-1, keepdim=True)
y = x - mean
# 4. 跨 layer 的 dead code elimination
@torch.compile(mode='max-autotune')
def block(x):
y = norm(x)
return y + x # Inductor 会自动认出 y 用过即可释放经验法则:算子链 ≤ 3 个 + 纯 pointwise → Inductor;包含 reduction 或要复用中间值 → 手写 Triton。
本章小结
- 算子融合的本质:把多个连续算子合并到一个核函数,让中间结果只在寄存器 / SRAM 里流转,减少 HBM 往返。
- memory-bound 算子是融合的最大受益者:LLM 中 39% 的运行时间花在 0.02% FLOPs 的小算子上,融合一次能省 4× 带宽。
- Triton 在融合场景的甜点:比 CUDA 简洁(~1/4 行数),比
torch.compile可控;尤其擅长 reduction + element-wise 复杂融合。 - 融合 Softmax 是经典教学案例:通过
tl.load → tl.max → tl.exp → tl.sum → tl.store一次完成全流程,A100 上可达 HBM 峰值带宽。 - 判断融合收益用流量公式:对 memory-bound 算子,加速比 ≈ HBM 流量节省比例;RMSNorm+Residual+SiLU 三算子融合实测 2.3× 与理论值完全吻合。
- 常见融合模式:element-wise chain、reduction + element-wise、GEMM + activation、attention(见下一章)、backward 融合。
- 融合的边界:SRAM 容量、并行模式差异、数据复用、跨核函数依赖——融合不是万灵药,融合超过 4~5 算子要警觉 SRAM 溢出。
- 手动 vs Inductor:核心算子手写(85~95% peak),边缘算子让 Inductor 处理(70~85% peak);2026 主流推理框架都是混合策略。
下一章我们把"融合"推向极致——FlashAttention 把 attention 的两个 GEMM + softmax 全部塞进一个核函数,并通过"在线 softmax + 不 materialize 注意力矩阵"把 HBM 访问从 O(N²) 降到 O(N),是融合思想的巅峰案例。
思考题
给定如下 PyTorch 代码:
pythony = torch.exp(x - x.max(dim=-1, keepdim=True)[0]) out = y / y.sum(dim=-1, keepdim=True) * scale + bias请估算朴素 PyTorch eager 模式与 Triton 融合核函数的 HBM 流量比,并说明融合后理论加速比。
7.3 节的融合 softmax 要求
n_cols ≤ ~32K(一行能装进 SRAM)。如果n_cols = 200000(如超大词表的 logits),单行装不进 SRAM 怎么办?请草拟一个多块 softmax 的设计思路(提示:online algorithm + 两阶段核函数)。你想把
dropout(gelu(linear(x)))三个算子融成一个 Triton 核函数用于训练。请列出至少 3 个需要注意的设计点(提示:考虑反向传播、PRNG 状态、SRAM 预算)。