8. FlashAttention 实战
用 Triton 一步步实现 FlashAttention,理解推动大模型长序列训练的关键算法。
FlashAttention(Tri Dao 等人,2022 / 2023)是近年最具影响力的 GPU 算法之一。它通过 分块计算 + 在线 softmax + 不 materialize 注意力矩阵,把 attention 的 HBM 访问从 O(N²) 降到 O(N)、显存从 O(N²) 降到 O(N),使得 GPT-3.5/4、LLaMA、Mistral 等模型能在合理的显存预算下训练数千乃至上万 token 的上下文。
本章带你从问题分析出发,一步步重写出官方 06-fused-attention.py 的核心思路,并对照配套示例 examples/04_flash_attention.py 逐段拆解。
本章内容概览
- 标准 attention 的内存瓶颈分析
- FlashAttention 的核心思想:分块 (tiling) + 在线 softmax
- 在线 softmax 算法推导
- Triton 实现详解(核函数逐段讲解)
- 性能分析与版本演进
- 进一步扩展:causal mask、backward、FP8
8.1 标准 Attention 的内存瓶颈
设 Q、K、V 形状均为 [N, d](N = seq_len,d = head_dim),标准 scaled-dot-product attention 的算式是:
$$ S = \frac{QK^\top}{\sqrt{d}}, \quad P = \mathrm{softmax}(S), \quad O = PV $$
朴素 PyTorch 实现:
def naive_attention(q, k, v, sm_scale):
s = torch.matmul(q, k.transpose(-2, -1)) * sm_scale # [N, N]
p = torch.softmax(s, dim=-1) # [N, N]
o = torch.matmul(p, v) # [N, d]
return o8.1.1 内存爆炸的本质
中间张量 S、P 的形状是 [N, N]。当 N = 4096、单 head、fp16 时:
$$ 4096 \times 4096 \times 2,\text{B} = 32,\text{MB} $$
而一张 H100 SM 的共享内存只有 228 KB,A100 是 192 KB。S 根本放不进 SRAM,必须落到 HBM,于是:
| 步骤 | HBM 读 | HBM 写 |
|---|---|---|
算 S = QK^T | 读 Q (Nd), 读 K (Nd) | 写 S (N²) |
算 P = softmax(S) | 读 S (N²) | 写 P (N²) |
算 O = PV | 读 P (N²), 读 V (Nd) | 写 O (Nd) |
合计 HBM 访问量约 3N² + 4Nd,随序列长度 平方级增长。当 N 从 1024 → 8192 时 HBM 流量增加 64 倍,而算力 (FLOPs) 同样增加 64 倍——也就是说,算力增长被白白浪费在搬数据上,attention 在长序列下从计算密集型沦为访存密集型。
关键瓶颈
不是 GPU 算不动 attention,而是 HBM 带宽喂不饱 Tensor Core。
8.1.2 显存峰值
更糟糕的是,训练时需要在 backward 阶段重用 P,朴素实现必须把 [B, H, N, N] 的 P 矩阵保留到 backward,显存占用直接随 N² 爆炸。8192 token、32 head、batch=4、fp16 单是 P 就要 64 GB——这就是为什么早期 GPT-2/3 训练 context 长度卡在 1K~2K。
8.2 FlashAttention 的核心思想
Tri Dao 在论文里给出了三个观察:
- 不需要看到完整的 S 才能做 softmax——softmax 可以"流式"地一块块算(online softmax)。
- 如果 P 不存储,O 也能直接算出来——把 P 留在寄存器里立刻乘 V。
- 如果 backward 不存 P,可以用 (m, l) 重算——多算一次 forward,省下 O(N²) 显存。
合起来:
FlashAttention 三连
Tiling(分块) + Online softmax(在线归一化) + Recomputation(重算)
效果:
- HBM 访问
O(N²)→O(N²d / M),M 是 SRAM 大小,实际等价于 O(N) - 显存
O(N²)→O(N)(只存最终的 O 和 forward 阶段的m, l) - 算力利用率:A100 上 fp16 attention 从 ~30% peak → ~70% peak
8.3 Online Softmax 推导
这是整个算法的"魔法"所在,必须搞清楚。
8.3.1 数值稳定的标准 softmax
给定向量 x = [x_1, ..., x_n],softmax 通常这样算:
m = max(x) # 减去最大值,防 exp 溢出
p_i = exp(x_i - m)
l = sum(p_i)
y_i = p_i / l这要 三次扫描:算 max、算 exp+sum、算 div。但我们能不能 一次扫描 就出结果?
8.3.2 增量公式
把 x 切成两段 x = [x_a, x_b]。先用 x_a 算出局部统计量:
m_a = max(x_a)
l_a = sum(exp(x_a - m_a))再看到 x_b,怎样合并?设:
m_b = max(x_b)
l_b = sum(exp(x_b - m_b))合并后的全局最大值:
$$ m_{\text{new}} = \max(m_a, m_b) $$
我们需要把两段的 l 都"对齐"到 m_new:
$$ l_{\text{new}} = e^{m_a - m_{\text{new}}} \cdot l_a + e^{m_b - m_{\text{new}}} \cdot l_b $$
这就是关键——旧的 l 只要乘一个修正系数 α = exp(m_old - m_new) 就能继续用,不必从头扫一遍。
8.3.3 把 O 一起推上去
attention 真正要的不是 softmax 本身,而是 O = softmax(S) · V。把 V 也按行分块,记 V 的对应块为 V_a, V_b,并定义"未归一化"的输出累加器:
$$ \tilde O = \sum_i e^{S_i - m} \cdot V_i $$
最终 O = \tilde O / l。增量公式变成:
m_new = max(m_old, max(S_block))
α = exp(m_old - m_new)
p = exp(S_block - m_new) # 局部归一化的 p
l_new = α * l_old + sum(p)
Õ_new = α * Õ_old + p @ V_block一句话总结
新块到来时:把旧的 (l, Õ) 乘上修正因子 α,再加上新块的贡献。 循环结束后 O = Õ / l。
这就是 FlashAttention forward 的全部数学。所有"看起来很复杂"的工程细节,都在围绕这几行公式做高效的 GPU 实现。
8.4 Triton 实现总览
我们对照 examples/04_flash_attention.py 一节节看。该示例是教学简化版(forward only、非 causal、不启用 warp_specialize / TMA),逻辑上是官方 06-fused-attention.py 的真子集,便于聚焦核心算法。
8.4.1 网格与程序实例划分
grid = (triton.cdiv(N_CTX, BLOCK_M), BATCH * HEADS, 1)axis=0:沿 Q 方向的块索引——每个程序实例负责输出 O 的一块[BLOCK_M, HEAD_DIM]axis=1:把 batch 和 head 平铺成一维——同一个程序实例只看一个 head 的数据
为什么 Q 分块、KV 不分块
FlashAttention v2 的关键改进就是把 outer loop 放在 Q 上、inner loop 放在 KV 上。这样 Q 块被加载一次后就可以全程驻留 SRAM;KV 在内循环中滑动,只占用一份分块的 SRAM。v1 反过来,导致 Q 反复 reload,吞吐低 ~2 倍。
8.4.2 加载 Q 并驻留 SRAM
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # 当前 Q 行索引
offs_k = tl.arange(0, HEAD_DIM) # head_dim 列索引
q_ptrs = Q + q_offset + (offs_m[:, None] * stride_qm
+ offs_k[None, :] * stride_qk)
q_mask = offs_m[:, None] < N_CTX
q = tl.load(q_ptrs, mask=q_mask, other=0.0)- 用
offs_m[:, None] * stride_m + offs_k[None, :] * stride_k这种广播写法构造 2D 指针块——Triton 中所有"二维分块"都是这么生成的。 q_mask处理N_CTX不是BLOCK_M整数倍的边界。- 这个
q张量在整个 K/V 循环中一直驻留寄存器,不再回 DRAM——这是 FA2 的核心优化。
8.4.3 初始化在线 softmax 累加器
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)对应 8.3 节的 m, l, Õ:
m_i:每行的 running max,初值-infl_i:每行的 running sum of exp,初值 0acc:每行的"未归一化"输出累加器,初值 0
一律 fp32 累加
即使 Q/K/V 是 fp16/bf16,累加器必须是 fp32——m, l, Õ 在长序列下会经过几十甚至上百次累加,半精度根本撑不住。FA2 论文 4.3 节专门讨论了这一点。
8.4.4 主循环:分块逐块扫描 K/V
for start_n in range(0, N_CTX, BLOCK_N):
offs_n = start_n + offs_n_base
# 1) 加载 K 块
k_ptrs = K + k_offset + (offs_n[:, None] * stride_kn
+ offs_k[None, :] * stride_kk)
k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX, other=0.0)
# 2) qk = Q_block @ K_block^T
qk = tl.dot(q, tl.trans(k))
qk = qk * sm_scale
qk = tl.where(offs_n[None, :] < N_CTX, qk, -float('inf'))
# 3) 在线 softmax 更新
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
alpha = tl.exp(m_i - m_ij)
p = tl.exp(qk - m_ij[:, None])
l_i = l_i * alpha + tl.sum(p, axis=1)
# 4) 加载 V,更新 acc
v_ptrs = V + v_offset + (offs_n[:, None] * stride_vn
+ offs_k[None, :] * stride_vk)
v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX, other=0.0)
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
# 5) 更新 running max
m_i = m_ij逐步拆解:
(1) 加载 K 块 — 形状 [BLOCK_N, HEAD_DIM],越界用 0 填充。
(2) 计算 qk — 形状 [BLOCK_M, BLOCK_N],乘上 sm_scale = 1/√d。越界列设 -inf 让它在 softmax 里贡献 0。
(3) 在线 softmax — 完全对应 8.3.3 的公式:
m_ij是合并后的 running maxalpha = exp(m_old - m_new)是旧累加器的修正系数p是用m_ij归一化后的本块 softmax 分子l_i用 alpha 修正后加上本块的 sum
(4) p @ V 并入 acc — 同样用 alpha 先修正旧 acc,再加上 p @ V_block。注意 p.to(v.dtype) 把 fp32 的 p 降回 fp16 给 tl.dot(Tensor Core 要求 fp16/bf16 输入)。
(5) 更新 m_i 给下一轮用。
为什么 acc 修正用 alpha[:, None]
alpha 形状是 [BLOCK_M](每行一个),acc 形状是 [BLOCK_M, HEAD_DIM],要广播到 head_dim 维度上。
8.4.5 收尾与写回
acc = acc / l_i[:, None]
o_ptrs = Out + o_offset + (offs_m[:, None] * stride_om
+ offs_k[None, :] * stride_ok)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=offs_m[:, None] < N_CTX)循环结束后 acc = Σ exp(S_i - m) · V_i,除以 l 就拿到真正的 O = softmax(S) · V。最后用 mask 写回 HBM。
整个核函数从头到尾,S 和 P 都只在寄存器里短暂存在,从未落地 HBM——这就是 FlashAttention 的全部秘密。
8.5 与朴素实现对比
examples/04_flash_attention.py 在 test_correctness() 中会同时跑朴素 attention 与 Triton 实现,期望最大绝对误差在 1e-2 以内(fp16 + 在线累加的典型误差量级)。
o_triton = flash_attention(q, k, v, sm_scale)
o_torch = naive_attention(q, k, v, sm_scale)
assert torch.allclose(o_triton, o_torch, atol=1e-2, rtol=1e-2)fp16 误差
不要期待 atol=1e-6:在线 softmax 涉及 exp 和大量累加,fp16 输入 + fp32 累加的最终误差通常在 1e-3 ~ 1e-2。这对训练完全够用,但写测试时不要照搬向量加法的精度阈值。
8.6 性能分析
8.6.1 算力估计
attention 的 FLOPs 主要在两次 matmul:
$$ \text{FLOPs} \approx 4 \cdot B \cdot H \cdot N^2 \cdot d $$
examples/04_flash_attention.py 的 benchmark 报告 TFLOPS:
flops_per_op = 4 * BATCH * HEADS * N_CTX * N_CTX * HEAD_DIM
perf = lambda ms: flops_per_op * 1e-12 / (ms * 1e-3)8.6.2 官方报告的数据(H100 / B200)
来自官方 06-fused-attention.py:
| 配置 | Triton FP16 | 手写 CUDA FA2 |
|---|---|---|
| H100, d=64, causal, fwd, seq=512 | ~165 TFLOPS | ~165 TFLOPS |
| H100, d=64, causal, fwd, seq=8192 | ~480 TFLOPS | ~422 TFLOPS |
| B200, d=128, fwd, warp_spec, seq=4096 | 526 TFLOPS | 401 TFLOPS |
反超 CUDA 实现
在新架构(Blackwell B200)上 Triton 已经多次出现反超 FA2 手写 CUDA 的报告——这是因为 Triton 编译器能针对新硬件自动调度 warp specialization、TMA 等新特性,手写 CUDA 反而要追着改。
8.6.3 内存节省
| 实现 | seq=4096 单 head 显存 | 训练 seq=8192 是否 OOM |
|---|---|---|
| PyTorch 朴素 | ~32 MB(P 矩阵) | 极易 OOM |
| FlashAttention | ~72 KiB(驻留 tile) | 显存几乎线性增长 |
社区独立测算 (Leandro Lacerda, 2025):FlashAttention 在 4096 token 时,PyTorch 占用 ~3 GB,Triton/CUDA 实现常驻 ~72 MiB。40 倍以上的显存节省。
8.7 进阶扩展
本章简化版只覆盖了基础 forward。完整工业级实现还涉及:
8.7.1 Causal mask
LLM 推理只能看到当前及之前的 token。官方实现用 STAGE 分裂 优化:
- STAGE=1:非对角块——
j < i全部要算,不需要 mask - STAGE=2:对角块——需要 mask,但只有一个块
- STAGE=3:全部(非 causal 路径)
# 简化示意
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)通过把"需要 mask 的对角块"和"完全不需要 mask 的非对角块"分开 codegen,省掉一组 if 分支。
8.7.2 Backward 与 recomputation
FA2 的 backward 不存 P,而是重新计算:
$$ P = \frac{\exp(QK^\top - m)}{l} $$
代价是多一次 forward 量级的算力,收益是显存 O(N²) → O(N)。Triton 实现分两个核函数:
_attn_bwd_preprocess:算delta = sum(O ⊙ dO)_attn_bwd:同时算 dK、dV、dQ
8.7.3 exp2 trick
硬件原生 exp2 比 exp 快很多。把所有 exp(x) 替换为 exp2(x · log2(e)),并把 log2(e) 直接乘进 sm_scale:
qk_scale = sm_scale * 1.44269504 # 1/ln(2)
p = tl.math.exp2(qk * qk_scale - m_ij)H100 上能带来 ~10% 提升。
8.7.4 Hopper / Blackwell 专属优化
- TMA (Tensor Memory Accelerator):用
tl.make_block_ptr/TensorDescriptor触发异步 DRAM↔SRAM 拷贝 - Warp specialization:把 producer warps(负责 load)和 consumer warps(负责 compute)分组并行,需要
tl.range(..., warp_specialize=True) - FP8:H100/B200 上 fp8 可再提速 ~1.6×
详见官方 06-fused-attention.py 与 FA2 论文。
8.8 反向传播核函数实现
8.4 节只讲了 forward,但训练真正难的是 backward——既要不存 P 矩阵(保持 O(N) 显存),又要算出 dQ, dK, dV 三个梯度。
8.8.1 核心挑战:不存 P 怎么反向?
标准 attention backward 需要 P = softmax(QK^T/√d) 来算梯度:
$$ \frac{\partial L}{\partial V} = P^\top \frac{\partial L}{\partial O}, \quad \frac{\partial L}{\partial Q} = \alpha \frac{\partial L}{\partial S} K, \quad \frac{\partial L}{\partial K} = \alpha \frac{\partial L}{\partial S}^\top Q $$
其中 dS = dP ⊙ P − P ⊙ (rowsum(dP ⊙ P)),dP = dO V^T。
朴素实现要存整个 P ∈ [N, N]——长序列直接 OOM。
8.8.2 Recomputation 策略
FlashAttention 的关键观察:forward 时只存 O ∈ [N, d]、l ∈ [N]、m ∈ [N] 三个张量;backward 时重新计算 P。代价:多一次 forward 量级的算力;收益:显存从 O(N²) 降到 O(N·d)。
具体来说,backward 时利用恒等式:
$$ P_{ij} = \exp(S_{ij} - m_i) / l_i $$
由于 S = QK^T/√d 可以重新算,而 m, l 已保存,于是 P 可以按 block 重新算出来,无需占用 HBM。
进一步还可以用 LSE_i = m_i + log(l_i)(logsumexp)合并两个标量:
$$ P_{ij} = \exp(S_{ij} - LSE_i) $$
省一半元数据存储。FA2 论文的伪代码就用这个版本。
8.8.3 dQ、dK、dV 的分块计算公式
定义 D_i = sum(O_i ⊙ dO_i)(一个 [N] 向量),论文证明:
$$ dS_{ij} = P_{ij} (dP_{ij} - D_i), \quad dP_{ij} = dO_i V_j^\top $$
这样每块 [BLOCK_M, BLOCK_N] 就能独立算 dS block 而不需要全局信息。
FA2 的 backward 算法(来自论文 Algorithm 2,简化):
Preprocess kernel:
D = rowsum(O ⊙ dO) ∈ [N]
写到 HBM
主 backward kernel:
parallel for each K, V block j:
加载 K_j, V_j 进 SRAM
初始化 dK_j = 0, dV_j = 0 (累加器)
for each Q block i:
加载 Q_i, dO_i, LSE_i, D_i 进 SRAM
重算 S_ij = Q_i K_j^T / √d
重算 P_ij = exp(S_ij - LSE_i) # O(BM × BN), 不写回 HBM
dV_j += P_ij^T dO_i # 累加到 dV_j
dP_ij = dO_i V_j^T # 重算 dP
dS_ij = P_ij ⊙ (dP_ij - D_i)
dK_j += dS_ij^T Q_i # 累加到 dK_j
atomic_add(dQ_i, dS_ij K_j) # dQ 跨 j 累加,需 atomic
写回 dK_j, dV_jdQ 必须用 atomic_add
因为 dQ_i 由所有 K/V block 累加而成,而不同 program 处理不同 j,所以需要原子加法(或者反向把 outer loop 放在 Q 上)。FA2 实测 atomic 版本性能损失约 5~10%。
8.8.4 Triton 实现要点(双重循环结构)
官方 06-fused-attention.py 的 backward 拆成两个 kernel:
_attn_bwd_preprocess:算 delta
@triton.jit
def _attn_bwd_preprocess(O, DO, Delta,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, HEAD_DIM)
# load
o = tl.load(O + off_hz * HEAD_DIM * N_CTX
+ off_m[:, None] * HEAD_DIM + off_n[None, :])
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX
+ off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
tl.store(Delta + off_hz * N_CTX + off_m, delta)_attn_bwd:主反向 kernel(双重循环)
@triton.jit
def _attn_bwd(Q, K, V, sm_scale,
DO, DQ, DK, DV,
M, D, # M = LSE, D = delta from preprocess
stride_z, stride_h, stride_tok, stride_d,
H, N_CTX,
BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, # outer-loop sizes
BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, # inner-loop sizes
HEAD_DIM: tl.constexpr):
bhid = tl.program_id(2)
off_chz = bhid * N_CTX
adj = stride_h * (bhid % H) + stride_z * (bhid // H)
pid = tl.program_id(0)
# 这一段处理 dK / dV:outer loop 在 K, inner loop 在 Q
start_n = pid * BLOCK_N1
start_m = 0
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
# 加载 K_j, V_j 全程驻留 SRAM
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
# inner loop: 对所有 Q block 累加 dK / dV
num_steps = N_CTX // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv, Q, k, v, sm_scale,
DO, M, D,
stride_tok, stride_d, H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,
start_n, start_m, num_steps,
MASK=False)
tl.store(DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, dv)
tl.store(DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, dk * sm_scale)
# 第二段:处理 dQ:outer loop 在 Q, inner loop 在 K
start_m = pid * BLOCK_M2
offs_m = start_m + tl.arange(0, BLOCK_M2)
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)[:, None]
num_steps = N_CTX // BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, do, m, D,
stride_tok, stride_d, H, N_CTX,
BLOCK_M2, BLOCK_N2, HEAD_DIM,
start_m, 0, num_steps, MASK=False)
dq *= sm_scale
tl.store(DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d, dq)设计要点:
- 两段式 kernel:第一段算 dK/dV(outer loop 在 K),第二段算 dQ(outer loop 在 Q)。这样避免了 atomic(每个 program 独占输出 tile)。
- 共享 inner-loop helper:
_attn_bwd_dkdv和_attn_bwd_dq都是 inner loop 实现,结构对称便于维护。 - mask 块单独处理:causal attention 把 mask block 和非 mask block 分两段循环,复用 forward 的 STAGE 优化。
- fp32 累加 + 末尾 cast:所有
dq, dk, dv累加器都是 fp32,写回前再降回 fp16/bf16。
8.8.5 性能数据
官方 06-fused-attention.py 在 H100 上的 backward benchmark:
| 配置 | Triton FP16 bwd | FA2 CUDA bwd |
|---|---|---|
| H100, seq=512, d=64, causal | ~120 TFLOPS | ~125 TFLOPS |
| H100, seq=8192, d=64, causal | ~310 TFLOPS | ~290 TFLOPS |
| H100, seq=8192, d=128, causal | ~250 TFLOPS | ~265 TFLOPS |
backward 比 forward 慢约 2~3 倍(额外的重算 + 双向梯度计算),但显存仍是 O(N),这是 FlashAttention 的最大价值。
8.9 FP8 FlashAttention
H100/B200 的 FP8 Tensor Core 提供约 2× 于 FP16 的吞吐(H100 SXM5:989 TFLOPS FP16 → 1979 TFLOPS FP8),但要拿到这个收益必须解决一个核心矛盾:softmax 需要高精度,GEMM 可以低精度。
8.9.1 为什么 FP8 attention 有意义
| 维度 | FP16 | FP8 (E4M3) |
|---|---|---|
| Tensor Core throughput (H100) | 989 TFLOPS | 1979 TFLOPS (2×) |
| Bytes per element | 2 | 1 (HBM 流量 0.5×) |
| 精度(尾数位) | 10 bit | 3 bit |
| 动态范围(指数位) | 5 bit | 4 bit |
理论加速:吞吐 2× + 流量 0.5× → 在 compute-bound 时 2×,memory-bound 时 2×,最坏 1.5×。FA3 论文实测在 H100 上 fp8 比 fp16 快 1.6×,达到 1.3 PFLOPS/s(理论峰值的 65%)。
8.9.2 核心挑战:softmax 精度
softmax 涉及 exp(x) 和大量累加,FP8 的 3 bit 尾数远远不够:
exp(-5) - exp(-6) = 0.00370,FP8 E4M3 直接量化为 0- 一个 attention 行内 N 个
exp值累加,FP8 的最大相对误差可达 10% - 长序列里这个误差会指数级放大
如果直接全 FP8,attention 输出的数值误差 > 5%,模型直接发散。
8.9.3 混合精度策略
FA3 的关键设计:FP8 用在 GEMM、FP32 累加 softmax、FP16 中转:
┌─────────────────────────────────────────────────┐
│ storage in HBM: Q, K, V ∈ FP8 (E4M3) │
│ scales_Q, scales_K, scales_V │
│ (per-block 量化系数) │
├─────────────────────────────────────────────────┤
│ Load to SRAM, dequantize to FP16 / FP32 by: │
│ Q_fp16 = Q_fp8.to(fp16) * scale_Q │
├─────────────────────────────────────────────────┤
│ GEMM (Tensor Core): │
│ S = Q_fp8 @ K_fp8^T (fp32 accumulator) │
│ 用 wgmma.mma_async 直接 FP8×FP8 → FP32 │
├─────────────────────────────────────────────────┤
│ Softmax (multi-function unit): │
│ 全程 FP32 计算 m, l, exp(qk - m) │
│ p_fp32 = exp(s_fp32 - m_fp32) / l_fp32 │
├─────────────────────────────────────────────────┤
│ Quantize p back to FP8 for next GEMM: │
│ p_fp8 = quantize(p_fp32, scale_p_block) │
├─────────────────────────────────────────────────┤
│ GEMM: │
│ O = p_fp8 @ V_fp8 (fp32 accumulator) │
└─────────────────────────────────────────────────┘8.9.4 动态 scaling factor
为防止 FP8 溢出/下溢,每个 block 维护一个 scale:
# 量化:x_fp32 → x_fp8
def quantize_to_fp8(x, scale):
# scale 选 max(|x|) / fp8_max(E4M3 max = 448)
return (x / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
# 反量化:x_fp8 → x_fp32(在 SRAM 内进行)
def dequantize(x_fp8, scale):
return x_fp8.to(torch.float32) * scaleFA3 采用 block quantization:每 BLOCK_M × HEAD_DIM 的 Q 块独立 scale,每 BLOCK_N × HEAD_DIM 的 K/V 块独立 scale。这比 per-tensor scaling 减少 2.6× 的数值误差(FA3 论文表 2)。
8.9.5 Incoherent processing
针对 LLM 的"outlier feature"问题(某几维幅值远大于其他维),FA3 引入 incoherent processing:
# 量化前用一个随机 Hadamard 矩阵 H 旋转 Q, K
Q' = Q @ H, K' = K @ H
# 因为 Q' K'^T = Q H H^T K^T = Q K^T(H 是正交),attention 不变
# 但 Q', K' 的分布被打散,outlier 不再集中在某几维这个 trick 把 FP8 attention 的精度提升到与 FP16 持平(< 0.5% 误差),是 FA3 能在生产环境落地的关键。
8.9.6 Triton 中的 FP8 实现
Triton 从 3.x 开始原生支持 tl.float8e4nv(E4M3)和 tl.float8e5(E5M2):
@triton.jit
def fa3_fp8_inner(q_fp8, scale_q,
k_ptr, v_ptr, scale_k_ptr, scale_v_ptr,
...):
# 加载 K block 并反量化到 fp16 for softmax
k_fp8 = tl.load(k_ptr + ...) # tl.float8e4nv tensor
scale_k = tl.load(scale_k_ptr + ...) # fp32 scalar
# FP8 × FP8 GEMM,fp32 累加(Tensor Core 直接支持)
qk_fp32 = tl.dot(q_fp8, tl.trans(k_fp8), out_dtype=tl.float32)
qk_fp32 = qk_fp32 * (scale_q * scale_k * sm_scale)
# 在线 softmax,全程 fp32
m_ij = tl.maximum(m_i, tl.max(qk_fp32, axis=1))
alpha = tl.math.exp2((m_i - m_ij) * 1.44269504)
p_fp32 = tl.math.exp2((qk_fp32 - m_ij[:, None]) * 1.44269504)
l_i = l_i * alpha + tl.sum(p_fp32, axis=1)
# 量化 P 到 FP8 用于第二个 GEMM
p_max = tl.max(tl.abs(p_fp32)) # 简化版:全局 scale;FA3 用 block scale
scale_p = p_max / 448.0
p_fp8 = (p_fp32 / scale_p).to(tl.float8e4nv)
v_fp8 = tl.load(v_ptr + ...)
scale_v = tl.load(scale_v_ptr + ...)
acc_fp32 = acc_fp32 * alpha[:, None]
acc_fp32 = tl.dot(p_fp8, v_fp8, acc_fp32, out_dtype=tl.float32)
acc_fp32 = acc_fp32 * (scale_p * scale_v)FP8 attention 仅限 H100+
FP8 Tensor Core 是 Hopper SM90 引入的,A100(SM80)和 L40S(SM89)都不支持。在 A100 上跑会自动 fallback 到 fp16,没有任何收益。
8.10 Warp Specialization 版本对比
FA2 和 FA3 的最大差异不在算法,而在 执行模型——FA3 引入了 producer-consumer 异步流水。
8.10.1 FA2:数据并行(单一执行流)
每个 warp 都做相同的事:
Time →
warp 0: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]
warp 1: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]
warp 2: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]
warp 3: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]问题:每个 warp 在 load 期间 Tensor Core 空转,在 compute 期间 LSU 空转。H100 测得这种模式利用率约 35%(FA3 论文)。
8.10.2 FA3:Warp Specialization(异步流水)
Producer warps (1 warp group = 4 warps):
Time →
[TMA load K_0] → [TMA load V_0] → [TMA load K_1] → [TMA load V_1] → ...
(专门发射 TMA load 指令,几乎不消耗 Tensor Core)
Consumer warps (1~2 warp groups):
Time →
[wait K_0] → [WGMMA QK] → [softmax_0] → [wait V_0] → [WGMMA PV] → ...
(专门做 WGMMA + softmax,从 SRAM 拿数据)
两组 warp 通过 mbarrier 同步,pipeline 全程并行!效率提升来自三点:
- TMA 异步:producer 发了 TMA load 指令后立即返回,Tensor Memory Accelerator 硬件单元后台搬数据
- WGMMA 异步:consumer 发了 wgmma.mma_async 后立即返回,Tensor Core 后台计算
- 寄存器再分配:Hopper 的
setmaxnreg指令让 producer 用少量寄存器(只发指令),consumer 拿到 90%+ 寄存器用于累加
8.10.3 Ping-pong 调度
FA3 在 consumer 内部还引入 ping-pong:双 buffer 让 GEMM 和 softmax 并行:
SRAM 双 buffer:
Buffer A: K_0, V_0
Buffer B: K_1, V_1
Consumer warpgroup 1: GEMM on A → softmax on A → GEMM on A2
Consumer warpgroup 2: GEMM on B → softmax on B → GEMM on B2
↑ ↑
两组 consumer 错开半个周期,互相隐藏 latency关键洞察:H100 的 special function unit (做 exp) 只有 3.9 TFLOPS,而 Tensor Core 有 989 TFLOPS(256× 差距)。Ping-pong 让两个 consumer 错开,让 GEMM 和 softmax 真正并行——一个 warpgroup 在算 GEMM 时,另一个在算 softmax。
8.10.4 吞吐量对比数据
H100 SXM5 实测(来自 FA3 论文 + Together AI 博客):
| 版本 | seq=8192 d=128 fp16 fwd | 利用率 | bwd |
|---|---|---|---|
| FA2 (CUDA) | ~350 TFLOPS | 35% | ~250 TFLOPS |
| FA3 (CUDA) | ~750 TFLOPS | 75% | ~530 TFLOPS |
| FA3 FP8 | ~1300 TFLOPS (1.3 PFLOPS) | 65% (FP8 peak) | — |
| Triton 06-fused-attention (v2 风格) | ~480 TFLOPS | 48% | ~310 TFLOPS |
Triton + warp_specialize=True | ~620 TFLOPS | 62% | — |
Triton 中启用 warp specialization
# 在 attention inner loop 加 warp_specialize=True
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=True):
...Triton 编译器会自动尝试拆分 producer / consumer warps。但要拿到 FA3 全部收益(ping-pong 等)还需要 Triton 3.4+,并配合 TensorDescriptor(TMA)。当前(2026 年 5 月)Triton 的 warp_specialize 已能拿到 ~85% FA3 CUDA 性能,相比 FA2 已有 1.3~1.5× 提升。
8.10.5 Triton 中的实现差异
| 维度 | Triton FA2 风格 | Triton FA3 风格 |
|---|---|---|
| 加载 K/V | tl.load(k_ptrs, ...) | desc_k.load([off, 0])(TensorDescriptor / TMA) |
| 主循环 | for start_n in range(...) | for start_n in tl.range(..., warp_specialize=True) |
| Tensor Core 指令 | mma.sync (Ampere) / wgmma(Hopper) | wgmma (async) |
| 同步 | 隐式(每条指令同步) | 显式 mbarrier(编译器生成) |
| 适用 GPU | A100 / H100 / B200 | 仅 H100+(需要 SM90 TMA) |
| 代码改动 | 少量 | 主循环重写 + TensorDescriptor 替换指针 |
实操建议:
- A100 / 旧硬件:继续用 FA2 风格 + 手写指针
- H100 + Triton 3.4+:尝试
warp_specialize=True,性能能拿到 FA3 CUDA 85% - B200:必须用 TensorDescriptor + warp_specialize,因为 SM100 的 TMA tile execution 要求
本章小结
- 朴素 attention 受限于
O(N²)的 HBM 访问与显存,长序列下 GPU 算力被白白浪费。 - FlashAttention 通过分块 (tiling) + 在线 softmax + 重算 (recomputation),把 HBM 访问降到
O(N),显存降到O(N)。 - 在线 softmax 的核心是:新块到来时用
α = exp(m_old - m_new)修正旧的累加器(l, Õ),循环结束后O = Õ / l。 - Triton 实现把 Q 块全程驻留 SRAM (共享内存),K/V 在内循环滑动,所有中间张量
S, P都不落地 HBM。 - Backward 用 recomputation:forward 只存
O, m, l三个 O(N) 张量,backward 时按 block 重算 P,显存仍 O(N);Triton 实现是 preprocess + 两段 inner loop 的双 kernel 结构。 - FP8 attention 通过 block quantization + incoherent processing,在 H100 上达到 1.3 PFLOPS(fp16 的 1.6×),同时精度损失可控(< 0.5% vs fp16)。
- FA3 的 warp specialization 把 producer(TMA load)和 consumer(WGMMA + softmax)拆成异步流水,H100 上从 FA2 的 35% 利用率提升到 75%;Triton 通过
warp_specialize=True能拿到约 85% 的 FA3 CUDA 性能。 - 性能上 Triton 实现已接近甚至反超手写 CUDA FA2,并且代码量约为 CUDA 的 1/4,是研究迭代和算子定制的首选。
配套代码
本章所有代码片段对应 examples/04_flash_attention.py,建议在 GPU 机器上完整跑一遍 test_correctness() 与 benchmark(),亲手观察长序列下的加速比。
思考题
关于 fp32 累加:本章核函数里
m_i、l_i、acc全部用 fp32 dtype,但 Q/K/V 是 fp16。如果把acc改成 fp16,会出什么问题?请结合N_CTX = 8192这种长序列估算最大累加误差。关于 alpha 修正:假设主循环跑到第 5 个 K/V 块,前 4 块的 max 都是
-2.0,第 5 块出现了max = 10.0。求此时alpha = exp(m_old - m_new)的具体数值,并说明这个数值在 acc 修正中的物理含义。关于 v1 vs v2:FlashAttention v1 把 outer loop 放在 K/V、inner loop 放在 Q;v2 反过来。试着画出两种调度下,Q、K、V 分块在 SRAM 中的进出次数,并解释为什么 v2 在长序列下快 ~2×。
下一章我们将系统总结 Triton 开发中的常用模式、性能优化清单、调试技巧与社区资源——把前八章学到的散点经验汇成一张可对照的速查表。