Skip to content

8. FlashAttention 实战

用 Triton 一步步实现 FlashAttention,理解推动大模型长序列训练的关键算法。

FlashAttention(Tri Dao 等人,2022 / 2023)是近年最具影响力的 GPU 算法之一。它通过 分块计算 + 在线 softmax + 不 materialize 注意力矩阵,把 attention 的 HBM 访问从 O(N²) 降到 O(N)、显存从 O(N²) 降到 O(N),使得 GPT-3.5/4、LLaMA、Mistral 等模型能在合理的显存预算下训练数千乃至上万 token 的上下文。

本章带你从问题分析出发,一步步重写出官方 06-fused-attention.py 的核心思路,并对照配套示例 examples/04_flash_attention.py 逐段拆解。

本章内容概览

  • 标准 attention 的内存瓶颈分析
  • FlashAttention 的核心思想:分块 (tiling) + 在线 softmax
  • 在线 softmax 算法推导
  • Triton 实现详解(核函数逐段讲解)
  • 性能分析与版本演进
  • 进一步扩展:causal mask、backward、FP8

8.1 标准 Attention 的内存瓶颈

设 Q、K、V 形状均为 [N, d](N = seq_len,d = head_dim),标准 scaled-dot-product attention 的算式是:

$$ S = \frac{QK^\top}{\sqrt{d}}, \quad P = \mathrm{softmax}(S), \quad O = PV $$

朴素 PyTorch 实现:

python
def naive_attention(q, k, v, sm_scale):
    s = torch.matmul(q, k.transpose(-2, -1)) * sm_scale  # [N, N]
    p = torch.softmax(s, dim=-1)                         # [N, N]
    o = torch.matmul(p, v)                               # [N, d]
    return o

8.1.1 内存爆炸的本质

中间张量 SP 的形状是 [N, N]。当 N = 4096、单 head、fp16 时:

$$ 4096 \times 4096 \times 2,\text{B} = 32,\text{MB} $$

而一张 H100 SM 的共享内存只有 228 KB,A100 是 192 KB。S 根本放不进 SRAM,必须落到 HBM,于是:

步骤HBM 读HBM 写
S = QK^T读 Q (Nd), 读 K (Nd)写 S ()
P = softmax(S)读 S ()写 P ()
O = PV读 P (), 读 V (Nd)写 O (Nd)

合计 HBM 访问量约 3N² + 4Nd,随序列长度 平方级增长。当 N 从 1024 → 8192 时 HBM 流量增加 64 倍,而算力 (FLOPs) 同样增加 64 倍——也就是说,算力增长被白白浪费在搬数据上,attention 在长序列下从计算密集型沦为访存密集型。

关键瓶颈

不是 GPU 算不动 attention,而是 HBM 带宽喂不饱 Tensor Core

8.1.2 显存峰值

更糟糕的是,训练时需要在 backward 阶段重用 P,朴素实现必须把 [B, H, N, N]P 矩阵保留到 backward,显存占用直接随 爆炸。8192 token、32 head、batch=4、fp16 单是 P 就要 64 GB——这就是为什么早期 GPT-2/3 训练 context 长度卡在 1K~2K。


8.2 FlashAttention 的核心思想

Tri Dao 在论文里给出了三个观察:

  1. 不需要看到完整的 S 才能做 softmax——softmax 可以"流式"地一块块算(online softmax)。
  2. 如果 P 不存储,O 也能直接算出来——把 P 留在寄存器里立刻乘 V。
  3. 如果 backward 不存 P,可以用 (m, l) 重算——多算一次 forward,省下 O(N²) 显存。

合起来:

FlashAttention 三连

Tiling(分块) + Online softmax(在线归一化) + Recomputation(重算)

效果:

  • HBM 访问 O(N²)O(N²d / M),M 是 SRAM 大小,实际等价于 O(N)
  • 显存 O(N²)O(N)(只存最终的 O 和 forward 阶段的 m, l
  • 算力利用率:A100 上 fp16 attention 从 ~30% peak → ~70% peak

8.3 Online Softmax 推导

这是整个算法的"魔法"所在,必须搞清楚。

8.3.1 数值稳定的标准 softmax

给定向量 x = [x_1, ..., x_n],softmax 通常这样算:

m = max(x)                  # 减去最大值,防 exp 溢出
p_i = exp(x_i - m)
l = sum(p_i)
y_i = p_i / l

这要 三次扫描:算 max、算 exp+sum、算 div。但我们能不能 一次扫描 就出结果?

8.3.2 增量公式

x 切成两段 x = [x_a, x_b]。先用 x_a 算出局部统计量:

m_a = max(x_a)
l_a = sum(exp(x_a - m_a))

再看到 x_b,怎样合并?设:

m_b = max(x_b)
l_b = sum(exp(x_b - m_b))

合并后的全局最大值:

$$ m_{\text{new}} = \max(m_a, m_b) $$

我们需要把两段的 l 都"对齐"到 m_new

$$ l_{\text{new}} = e^{m_a - m_{\text{new}}} \cdot l_a + e^{m_b - m_{\text{new}}} \cdot l_b $$

这就是关键——旧的 l 只要乘一个修正系数 α = exp(m_old - m_new) 就能继续用,不必从头扫一遍。

8.3.3 把 O 一起推上去

attention 真正要的不是 softmax 本身,而是 O = softmax(S) · V。把 V 也按行分块,记 V 的对应块为 V_a, V_b,并定义"未归一化"的输出累加器:

$$ \tilde O = \sum_i e^{S_i - m} \cdot V_i $$

最终 O = \tilde O / l。增量公式变成:

m_new = max(m_old, max(S_block))
α     = exp(m_old - m_new)
p     = exp(S_block - m_new)             # 局部归一化的 p
l_new = α * l_old + sum(p)
Õ_new = α * Õ_old + p @ V_block

一句话总结

新块到来时:把旧的 (l, Õ) 乘上修正因子 α,再加上新块的贡献。 循环结束后 O = Õ / l

这就是 FlashAttention forward 的全部数学。所有"看起来很复杂"的工程细节,都在围绕这几行公式做高效的 GPU 实现。


8.4 Triton 实现总览

我们对照 examples/04_flash_attention.py 一节节看。该示例是教学简化版(forward only、非 causal、不启用 warp_specialize / TMA),逻辑上是官方 06-fused-attention.py 的真子集,便于聚焦核心算法。

8.4.1 网格与程序实例划分

python
grid = (triton.cdiv(N_CTX, BLOCK_M), BATCH * HEADS, 1)
  • axis=0:沿 Q 方向的块索引——每个程序实例负责输出 O 的一块 [BLOCK_M, HEAD_DIM]
  • axis=1:把 batch 和 head 平铺成一维——同一个程序实例只看一个 head 的数据

为什么 Q 分块、KV 不分块

FlashAttention v2 的关键改进就是把 outer loop 放在 Q 上、inner loop 放在 KV 上。这样 Q 块被加载一次后就可以全程驻留 SRAM;KV 在内循环中滑动,只占用一份分块的 SRAM。v1 反过来,导致 Q 反复 reload,吞吐低 ~2 倍。

8.4.2 加载 Q 并驻留 SRAM

python
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # 当前 Q 行索引
offs_k = tl.arange(0, HEAD_DIM)                     # head_dim 列索引

q_ptrs = Q + q_offset + (offs_m[:, None] * stride_qm
                       + offs_k[None, :] * stride_qk)
q_mask = offs_m[:, None] < N_CTX
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
  • offs_m[:, None] * stride_m + offs_k[None, :] * stride_k 这种广播写法构造 2D 指针块——Triton 中所有"二维分块"都是这么生成的。
  • q_mask 处理 N_CTX 不是 BLOCK_M 整数倍的边界。
  • 这个 q 张量在整个 K/V 循环中一直驻留寄存器,不再回 DRAM——这是 FA2 的核心优化。

8.4.3 初始化在线 softmax 累加器

python
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

对应 8.3 节的 m, l, Õ

  • m_i:每行的 running max,初值 -inf
  • l_i:每行的 running sum of exp,初值 0
  • acc:每行的"未归一化"输出累加器,初值 0

一律 fp32 累加

即使 Q/K/V 是 fp16/bf16,累加器必须是 fp32——m, l, Õ 在长序列下会经过几十甚至上百次累加,半精度根本撑不住。FA2 论文 4.3 节专门讨论了这一点。

8.4.4 主循环:分块逐块扫描 K/V

python
for start_n in range(0, N_CTX, BLOCK_N):
    offs_n = start_n + offs_n_base

    # 1) 加载 K 块
    k_ptrs = K + k_offset + (offs_n[:, None] * stride_kn
                           + offs_k[None, :] * stride_kk)
    k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX, other=0.0)

    # 2) qk = Q_block @ K_block^T
    qk = tl.dot(q, tl.trans(k))
    qk = qk * sm_scale
    qk = tl.where(offs_n[None, :] < N_CTX, qk, -float('inf'))

    # 3) 在线 softmax 更新
    m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
    alpha = tl.exp(m_i - m_ij)
    p     = tl.exp(qk - m_ij[:, None])
    l_i   = l_i * alpha + tl.sum(p, axis=1)

    # 4) 加载 V,更新 acc
    v_ptrs = V + v_offset + (offs_n[:, None] * stride_vn
                           + offs_k[None, :] * stride_vk)
    v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX, other=0.0)
    acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)

    # 5) 更新 running max
    m_i = m_ij

逐步拆解:

(1) 加载 K 块 — 形状 [BLOCK_N, HEAD_DIM],越界用 0 填充。

(2) 计算 qk — 形状 [BLOCK_M, BLOCK_N],乘上 sm_scale = 1/√d。越界列设 -inf 让它在 softmax 里贡献 0。

(3) 在线 softmax — 完全对应 8.3.3 的公式:

  • m_ij 是合并后的 running max
  • alpha = exp(m_old - m_new) 是旧累加器的修正系数
  • p 是用 m_ij 归一化后的本块 softmax 分子
  • l_i 用 alpha 修正后加上本块的 sum

(4) p @ V 并入 acc — 同样用 alpha 先修正旧 acc,再加上 p @ V_block。注意 p.to(v.dtype) 把 fp32 的 p 降回 fp16 给 tl.dot(Tensor Core 要求 fp16/bf16 输入)。

(5) 更新 m_i 给下一轮用。

为什么 acc 修正用 alpha[:, None]

alpha 形状是 [BLOCK_M](每行一个),acc 形状是 [BLOCK_M, HEAD_DIM],要广播到 head_dim 维度上。

8.4.5 收尾与写回

python
acc = acc / l_i[:, None]

o_ptrs = Out + o_offset + (offs_m[:, None] * stride_om
                         + offs_k[None, :] * stride_ok)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=offs_m[:, None] < N_CTX)

循环结束后 acc = Σ exp(S_i - m) · V_i,除以 l 就拿到真正的 O = softmax(S) · V。最后用 mask 写回 HBM。

整个核函数从头到尾,S 和 P 都只在寄存器里短暂存在,从未落地 HBM——这就是 FlashAttention 的全部秘密。


8.5 与朴素实现对比

examples/04_flash_attention.pytest_correctness() 中会同时跑朴素 attention 与 Triton 实现,期望最大绝对误差在 1e-2 以内(fp16 + 在线累加的典型误差量级)。

python
o_triton = flash_attention(q, k, v, sm_scale)
o_torch  = naive_attention(q, k, v, sm_scale)
assert torch.allclose(o_triton, o_torch, atol=1e-2, rtol=1e-2)

fp16 误差

不要期待 atol=1e-6:在线 softmax 涉及 exp 和大量累加,fp16 输入 + fp32 累加的最终误差通常在 1e-3 ~ 1e-2。这对训练完全够用,但写测试时不要照搬向量加法的精度阈值。


8.6 性能分析

8.6.1 算力估计

attention 的 FLOPs 主要在两次 matmul:

$$ \text{FLOPs} \approx 4 \cdot B \cdot H \cdot N^2 \cdot d $$

examples/04_flash_attention.py 的 benchmark 报告 TFLOPS:

python
flops_per_op = 4 * BATCH * HEADS * N_CTX * N_CTX * HEAD_DIM
perf = lambda ms: flops_per_op * 1e-12 / (ms * 1e-3)

8.6.2 官方报告的数据(H100 / B200)

来自官方 06-fused-attention.py

配置Triton FP16手写 CUDA FA2
H100, d=64, causal, fwd, seq=512~165 TFLOPS~165 TFLOPS
H100, d=64, causal, fwd, seq=8192~480 TFLOPS~422 TFLOPS
B200, d=128, fwd, warp_spec, seq=4096526 TFLOPS401 TFLOPS

反超 CUDA 实现

在新架构(Blackwell B200)上 Triton 已经多次出现反超 FA2 手写 CUDA 的报告——这是因为 Triton 编译器能针对新硬件自动调度 warp specialization、TMA 等新特性,手写 CUDA 反而要追着改。

8.6.3 内存节省

实现seq=4096 单 head 显存训练 seq=8192 是否 OOM
PyTorch 朴素~32 MB(P 矩阵)极易 OOM
FlashAttention~72 KiB(驻留 tile)显存几乎线性增长

社区独立测算 (Leandro Lacerda, 2025):FlashAttention 在 4096 token 时,PyTorch 占用 ~3 GB,Triton/CUDA 实现常驻 ~72 MiB。40 倍以上的显存节省


8.7 进阶扩展

本章简化版只覆盖了基础 forward。完整工业级实现还涉及:

8.7.1 Causal mask

LLM 推理只能看到当前及之前的 token。官方实现用 STAGE 分裂 优化:

  • STAGE=1:非对角块——j < i 全部要算,不需要 mask
  • STAGE=2:对角块——需要 mask,但只有一个块
  • STAGE=3:全部(非 causal 路径)
python
# 简化示意
if STAGE == 2:
    mask = offs_m[:, None] >= (start_n + offs_n[None, :])
    qk = qk + tl.where(mask, 0, -1.0e6)

通过把"需要 mask 的对角块"和"完全不需要 mask 的非对角块"分开 codegen,省掉一组 if 分支。

8.7.2 Backward 与 recomputation

FA2 的 backward 不存 P,而是重新计算:

$$ P = \frac{\exp(QK^\top - m)}{l} $$

代价是多一次 forward 量级的算力,收益是显存 O(N²) → O(N)。Triton 实现分两个核函数:

  • _attn_bwd_preprocess:算 delta = sum(O ⊙ dO)
  • _attn_bwd:同时算 dK、dV、dQ

8.7.3 exp2 trick

硬件原生 exp2exp 快很多。把所有 exp(x) 替换为 exp2(x · log2(e)),并把 log2(e) 直接乘进 sm_scale

python
qk_scale = sm_scale * 1.44269504  # 1/ln(2)
p = tl.math.exp2(qk * qk_scale - m_ij)

H100 上能带来 ~10% 提升。

8.7.4 Hopper / Blackwell 专属优化

  • TMA (Tensor Memory Accelerator):用 tl.make_block_ptr / TensorDescriptor 触发异步 DRAM↔SRAM 拷贝
  • Warp specialization:把 producer warps(负责 load)和 consumer warps(负责 compute)分组并行,需要 tl.range(..., warp_specialize=True)
  • FP8:H100/B200 上 fp8 可再提速 ~1.6×

详见官方 06-fused-attention.py 与 FA2 论文。


8.8 反向传播核函数实现

8.4 节只讲了 forward,但训练真正难的是 backward——既要不存 P 矩阵(保持 O(N) 显存),又要算出 dQ, dK, dV 三个梯度。

8.8.1 核心挑战:不存 P 怎么反向?

标准 attention backward 需要 P = softmax(QK^T/√d) 来算梯度:

$$ \frac{\partial L}{\partial V} = P^\top \frac{\partial L}{\partial O}, \quad \frac{\partial L}{\partial Q} = \alpha \frac{\partial L}{\partial S} K, \quad \frac{\partial L}{\partial K} = \alpha \frac{\partial L}{\partial S}^\top Q $$

其中 dS = dP ⊙ P − P ⊙ (rowsum(dP ⊙ P))dP = dO V^T

朴素实现要存整个 P ∈ [N, N]——长序列直接 OOM。

8.8.2 Recomputation 策略

FlashAttention 的关键观察:forward 时只存 O ∈ [N, d]l ∈ [N]m ∈ [N] 三个张量;backward 时重新计算 P。代价:多一次 forward 量级的算力;收益:显存从 O(N²) 降到 O(N·d)

具体来说,backward 时利用恒等式:

$$ P_{ij} = \exp(S_{ij} - m_i) / l_i $$

由于 S = QK^T/√d 可以重新算,而 m, l 已保存,于是 P 可以按 block 重新算出来,无需占用 HBM。

进一步还可以用 LSE_i = m_i + log(l_i)(logsumexp)合并两个标量:

$$ P_{ij} = \exp(S_{ij} - LSE_i) $$

省一半元数据存储。FA2 论文的伪代码就用这个版本。

8.8.3 dQ、dK、dV 的分块计算公式

定义 D_i = sum(O_i ⊙ dO_i)(一个 [N] 向量),论文证明:

$$ dS_{ij} = P_{ij} (dP_{ij} - D_i), \quad dP_{ij} = dO_i V_j^\top $$

这样每块 [BLOCK_M, BLOCK_N] 就能独立算 dS block 而不需要全局信息。

FA2 的 backward 算法(来自论文 Algorithm 2,简化):

text
Preprocess kernel:
  D = rowsum(O ⊙ dO)   ∈ [N]
  写到 HBM

主 backward kernel:
  parallel for each K, V block j:
    加载 K_j, V_j 进 SRAM
    初始化 dK_j = 0, dV_j = 0   (累加器)
    for each Q block i:
      加载 Q_i, dO_i, LSE_i, D_i 进 SRAM
      重算 S_ij = Q_i K_j^T / √d
      重算 P_ij = exp(S_ij - LSE_i)         # O(BM × BN), 不写回 HBM
      dV_j += P_ij^T dO_i                    # 累加到 dV_j
      dP_ij = dO_i V_j^T                     # 重算 dP
      dS_ij = P_ij ⊙ (dP_ij - D_i)
      dK_j += dS_ij^T Q_i                    # 累加到 dK_j
      atomic_add(dQ_i, dS_ij K_j)            # dQ 跨 j 累加,需 atomic
    写回 dK_j, dV_j

dQ 必须用 atomic_add

因为 dQ_i 由所有 K/V block 累加而成,而不同 program 处理不同 j,所以需要原子加法(或者反向把 outer loop 放在 Q 上)。FA2 实测 atomic 版本性能损失约 5~10%。

8.8.4 Triton 实现要点(双重循环结构)

官方 06-fused-attention.py 的 backward 拆成两个 kernel:

_attn_bwd_preprocess:算 delta
python
@triton.jit
def _attn_bwd_preprocess(O, DO, Delta,
                         Z, H, N_CTX,
                         BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX
                + off_m[:, None] * HEAD_DIM + off_n[None, :])
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX
                 + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(Delta + off_hz * N_CTX + off_m, delta)
_attn_bwd:主反向 kernel(双重循环)
python
@triton.jit
def _attn_bwd(Q, K, V, sm_scale,
              DO, DQ, DK, DV,
              M, D,              # M = LSE, D = delta from preprocess
              stride_z, stride_h, stride_tok, stride_d,
              H, N_CTX,
              BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,  # outer-loop sizes
              BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr,  # inner-loop sizes
              HEAD_DIM: tl.constexpr):
    bhid = tl.program_id(2)
    off_chz = bhid * N_CTX
    adj = stride_h * (bhid % H) + stride_z * (bhid // H)
    pid = tl.program_id(0)

    # 这一段处理 dK / dV:outer loop 在 K, inner loop 在 Q
    start_n = pid * BLOCK_N1
    start_m = 0
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # 加载 K_j, V_j 全程驻留 SRAM
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    # inner loop: 对所有 Q block 累加 dK / dV
    num_steps = N_CTX // MASK_BLOCK_M1
    dk, dv = _attn_bwd_dkdv(dk, dv, Q, k, v, sm_scale,
                            DO, M, D,
                            stride_tok, stride_d, H, N_CTX,
                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,
                            start_n, start_m, num_steps,
                            MASK=False)

    tl.store(DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, dv)
    tl.store(DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, dk * sm_scale)

    # 第二段:处理 dQ:outer loop 在 Q, inner loop 在 K
    start_m = pid * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    m = tl.load(M + offs_m)[:, None]

    num_steps = N_CTX // BLOCK_N2
    dq = _attn_bwd_dq(dq, q, K, V, do, m, D,
                      stride_tok, stride_d, H, N_CTX,
                      BLOCK_M2, BLOCK_N2, HEAD_DIM,
                      start_m, 0, num_steps, MASK=False)

    dq *= sm_scale
    tl.store(DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d, dq)

设计要点:

  1. 两段式 kernel:第一段算 dK/dV(outer loop 在 K),第二段算 dQ(outer loop 在 Q)。这样避免了 atomic(每个 program 独占输出 tile)。
  2. 共享 inner-loop helper_attn_bwd_dkdv_attn_bwd_dq 都是 inner loop 实现,结构对称便于维护。
  3. mask 块单独处理:causal attention 把 mask block 和非 mask block 分两段循环,复用 forward 的 STAGE 优化。
  4. fp32 累加 + 末尾 cast:所有 dq, dk, dv 累加器都是 fp32,写回前再降回 fp16/bf16。

8.8.5 性能数据

官方 06-fused-attention.py 在 H100 上的 backward benchmark:

配置Triton FP16 bwdFA2 CUDA bwd
H100, seq=512, d=64, causal~120 TFLOPS~125 TFLOPS
H100, seq=8192, d=64, causal~310 TFLOPS~290 TFLOPS
H100, seq=8192, d=128, causal~250 TFLOPS~265 TFLOPS

backward 比 forward 慢约 2~3 倍(额外的重算 + 双向梯度计算),但显存仍是 O(N),这是 FlashAttention 的最大价值。

8.9 FP8 FlashAttention

H100/B200 的 FP8 Tensor Core 提供约 2× 于 FP16 的吞吐(H100 SXM5:989 TFLOPS FP16 → 1979 TFLOPS FP8),但要拿到这个收益必须解决一个核心矛盾:softmax 需要高精度,GEMM 可以低精度

8.9.1 为什么 FP8 attention 有意义

维度FP16FP8 (E4M3)
Tensor Core throughput (H100)989 TFLOPS1979 TFLOPS (2×)
Bytes per element21 (HBM 流量 0.5×)
精度(尾数位)10 bit3 bit
动态范围(指数位)5 bit4 bit

理论加速:吞吐 2× + 流量 0.5× → 在 compute-bound 时 2×,memory-bound 时 2×,最坏 1.5×。FA3 论文实测在 H100 上 fp8 比 fp16 快 1.6×,达到 1.3 PFLOPS/s(理论峰值的 65%)。

8.9.2 核心挑战:softmax 精度

softmax 涉及 exp(x) 和大量累加,FP8 的 3 bit 尾数远远不够:

  • exp(-5) - exp(-6) = 0.00370,FP8 E4M3 直接量化为 0
  • 一个 attention 行内 N 个 exp 值累加,FP8 的最大相对误差可达 10%
  • 长序列里这个误差会指数级放大

如果直接全 FP8,attention 输出的数值误差 > 5%,模型直接发散。

8.9.3 混合精度策略

FA3 的关键设计:FP8 用在 GEMM、FP32 累加 softmax、FP16 中转

text
┌─────────────────────────────────────────────────┐
│  storage in HBM:    Q, K, V ∈ FP8 (E4M3)         │
│                     scales_Q, scales_K, scales_V │
│                     (per-block 量化系数)         │
├─────────────────────────────────────────────────┤
│  Load to SRAM, dequantize to FP16 / FP32 by:    │
│    Q_fp16 = Q_fp8.to(fp16) * scale_Q             │
├─────────────────────────────────────────────────┤
│  GEMM (Tensor Core):                            │
│    S = Q_fp8 @ K_fp8^T  (fp32 accumulator)      │
│    用 wgmma.mma_async 直接 FP8×FP8 → FP32       │
├─────────────────────────────────────────────────┤
│  Softmax (multi-function unit):                 │
│    全程 FP32 计算 m, l, exp(qk - m)              │
│    p_fp32 = exp(s_fp32 - m_fp32) / l_fp32        │
├─────────────────────────────────────────────────┤
│  Quantize p back to FP8 for next GEMM:          │
│    p_fp8 = quantize(p_fp32, scale_p_block)      │
├─────────────────────────────────────────────────┤
│  GEMM:                                          │
│    O = p_fp8 @ V_fp8  (fp32 accumulator)        │
└─────────────────────────────────────────────────┘

8.9.4 动态 scaling factor

为防止 FP8 溢出/下溢,每个 block 维护一个 scale:

python
# 量化:x_fp32 → x_fp8
def quantize_to_fp8(x, scale):
    # scale 选 max(|x|) / fp8_max(E4M3 max = 448)
    return (x / scale).clamp(-448, 448).to(torch.float8_e4m3fn)

# 反量化:x_fp8 → x_fp32(在 SRAM 内进行)
def dequantize(x_fp8, scale):
    return x_fp8.to(torch.float32) * scale

FA3 采用 block quantization:每 BLOCK_M × HEAD_DIM 的 Q 块独立 scale,每 BLOCK_N × HEAD_DIM 的 K/V 块独立 scale。这比 per-tensor scaling 减少 2.6× 的数值误差(FA3 论文表 2)。

8.9.5 Incoherent processing

针对 LLM 的"outlier feature"问题(某几维幅值远大于其他维),FA3 引入 incoherent processing

python
# 量化前用一个随机 Hadamard 矩阵 H 旋转 Q, K
Q' = Q @ H,  K' = K @ H
# 因为 Q' K'^T = Q H H^T K^T = Q K^T(H 是正交),attention 不变
# 但 Q', K' 的分布被打散,outlier 不再集中在某几维

这个 trick 把 FP8 attention 的精度提升到与 FP16 持平(< 0.5% 误差),是 FA3 能在生产环境落地的关键。

8.9.6 Triton 中的 FP8 实现

Triton 从 3.x 开始原生支持 tl.float8e4nv(E4M3)和 tl.float8e5(E5M2):

python
@triton.jit
def fa3_fp8_inner(q_fp8, scale_q,
                  k_ptr, v_ptr, scale_k_ptr, scale_v_ptr,
                  ...):
    # 加载 K block 并反量化到 fp16 for softmax
    k_fp8 = tl.load(k_ptr + ...)            # tl.float8e4nv tensor
    scale_k = tl.load(scale_k_ptr + ...)    # fp32 scalar
    # FP8 × FP8 GEMM,fp32 累加(Tensor Core 直接支持)
    qk_fp32 = tl.dot(q_fp8, tl.trans(k_fp8), out_dtype=tl.float32)
    qk_fp32 = qk_fp32 * (scale_q * scale_k * sm_scale)

    # 在线 softmax,全程 fp32
    m_ij = tl.maximum(m_i, tl.max(qk_fp32, axis=1))
    alpha = tl.math.exp2((m_i - m_ij) * 1.44269504)
    p_fp32 = tl.math.exp2((qk_fp32 - m_ij[:, None]) * 1.44269504)
    l_i = l_i * alpha + tl.sum(p_fp32, axis=1)

    # 量化 P 到 FP8 用于第二个 GEMM
    p_max = tl.max(tl.abs(p_fp32))   # 简化版:全局 scale;FA3 用 block scale
    scale_p = p_max / 448.0
    p_fp8 = (p_fp32 / scale_p).to(tl.float8e4nv)

    v_fp8 = tl.load(v_ptr + ...)
    scale_v = tl.load(scale_v_ptr + ...)
    acc_fp32 = acc_fp32 * alpha[:, None]
    acc_fp32 = tl.dot(p_fp8, v_fp8, acc_fp32, out_dtype=tl.float32)
    acc_fp32 = acc_fp32 * (scale_p * scale_v)

FP8 attention 仅限 H100+

FP8 Tensor Core 是 Hopper SM90 引入的,A100(SM80)和 L40S(SM89)都不支持。在 A100 上跑会自动 fallback 到 fp16,没有任何收益。

8.10 Warp Specialization 版本对比

FA2 和 FA3 的最大差异不在算法,而在 执行模型——FA3 引入了 producer-consumer 异步流水。

8.10.1 FA2:数据并行(单一执行流)

text
每个 warp 都做相同的事:
  Time →
  warp 0: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]
  warp 1: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]
  warp 2: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]
  warp 3: [load K] → [compute QK] → [softmax] → [load V] → [compute PV]

问题:每个 warp 在 load 期间 Tensor Core 空转,在 compute 期间 LSU 空转。H100 测得这种模式利用率约 35%(FA3 论文)。

8.10.2 FA3:Warp Specialization(异步流水)

text
Producer warps (1 warp group = 4 warps):
  Time →
  [TMA load K_0] → [TMA load V_0] → [TMA load K_1] → [TMA load V_1] → ...
  (专门发射 TMA load 指令,几乎不消耗 Tensor Core)

Consumer warps (1~2 warp groups):
  Time →
  [wait K_0] → [WGMMA QK] → [softmax_0] → [wait V_0] → [WGMMA PV] → ...
  (专门做 WGMMA + softmax,从 SRAM 拿数据)

两组 warp 通过 mbarrier 同步,pipeline 全程并行!

效率提升来自三点:

  1. TMA 异步:producer 发了 TMA load 指令后立即返回,Tensor Memory Accelerator 硬件单元后台搬数据
  2. WGMMA 异步:consumer 发了 wgmma.mma_async 后立即返回,Tensor Core 后台计算
  3. 寄存器再分配:Hopper 的 setmaxnreg 指令让 producer 用少量寄存器(只发指令),consumer 拿到 90%+ 寄存器用于累加

8.10.3 Ping-pong 调度

FA3 在 consumer 内部还引入 ping-pong:双 buffer 让 GEMM 和 softmax 并行:

text
SRAM 双 buffer:
  Buffer A: K_0, V_0
  Buffer B: K_1, V_1

Consumer warpgroup 1: GEMM on A → softmax on A → GEMM on A2
Consumer warpgroup 2:              GEMM on B → softmax on B → GEMM on B2
                       ↑           ↑
                       两组 consumer 错开半个周期,互相隐藏 latency

关键洞察:H100 的 special function unit (做 exp) 只有 3.9 TFLOPS,而 Tensor Core 有 989 TFLOPS(256× 差距)。Ping-pong 让两个 consumer 错开,让 GEMM 和 softmax 真正并行——一个 warpgroup 在算 GEMM 时,另一个在算 softmax。

8.10.4 吞吐量对比数据

H100 SXM5 实测(来自 FA3 论文 + Together AI 博客):

版本seq=8192 d=128 fp16 fwd利用率bwd
FA2 (CUDA)~350 TFLOPS35%~250 TFLOPS
FA3 (CUDA)~750 TFLOPS75%~530 TFLOPS
FA3 FP8~1300 TFLOPS (1.3 PFLOPS)65% (FP8 peak)
Triton 06-fused-attention (v2 风格)~480 TFLOPS48%~310 TFLOPS
Triton + warp_specialize=True~620 TFLOPS62%

Triton 中启用 warp specialization

python
# 在 attention inner loop 加 warp_specialize=True
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=True):
    ...

Triton 编译器会自动尝试拆分 producer / consumer warps。但要拿到 FA3 全部收益(ping-pong 等)还需要 Triton 3.4+,并配合 TensorDescriptor(TMA)。当前(2026 年 5 月)Triton 的 warp_specialize 已能拿到 ~85% FA3 CUDA 性能,相比 FA2 已有 1.3~1.5× 提升。

8.10.5 Triton 中的实现差异

维度Triton FA2 风格Triton FA3 风格
加载 K/Vtl.load(k_ptrs, ...)desc_k.load([off, 0])(TensorDescriptor / TMA)
主循环for start_n in range(...)for start_n in tl.range(..., warp_specialize=True)
Tensor Core 指令mma.sync (Ampere) / wgmma(Hopper)wgmma (async)
同步隐式(每条指令同步)显式 mbarrier(编译器生成)
适用 GPUA100 / H100 / B200仅 H100+(需要 SM90 TMA)
代码改动少量主循环重写 + TensorDescriptor 替换指针

实操建议:

  • A100 / 旧硬件:继续用 FA2 风格 + 手写指针
  • H100 + Triton 3.4+:尝试 warp_specialize=True,性能能拿到 FA3 CUDA 85%
  • B200:必须用 TensorDescriptor + warp_specialize,因为 SM100 的 TMA tile execution 要求

本章小结

  • 朴素 attention 受限于 O(N²) 的 HBM 访问与显存,长序列下 GPU 算力被白白浪费。
  • FlashAttention 通过分块 (tiling) + 在线 softmax + 重算 (recomputation),把 HBM 访问降到 O(N),显存降到 O(N)
  • 在线 softmax 的核心是:新块到来时用 α = exp(m_old - m_new) 修正旧的累加器 (l, Õ),循环结束后 O = Õ / l
  • Triton 实现把 Q 块全程驻留 SRAM (共享内存),K/V 在内循环滑动,所有中间张量 S, P 都不落地 HBM。
  • Backward 用 recomputation:forward 只存 O, m, l 三个 O(N) 张量,backward 时按 block 重算 P,显存仍 O(N);Triton 实现是 preprocess + 两段 inner loop 的双 kernel 结构。
  • FP8 attention 通过 block quantization + incoherent processing,在 H100 上达到 1.3 PFLOPS(fp16 的 1.6×),同时精度损失可控(< 0.5% vs fp16)。
  • FA3 的 warp specialization 把 producer(TMA load)和 consumer(WGMMA + softmax)拆成异步流水,H100 上从 FA2 的 35% 利用率提升到 75%;Triton 通过 warp_specialize=True 能拿到约 85% 的 FA3 CUDA 性能。
  • 性能上 Triton 实现已接近甚至反超手写 CUDA FA2,并且代码量约为 CUDA 的 1/4,是研究迭代和算子定制的首选。

配套代码

本章所有代码片段对应 examples/04_flash_attention.py,建议在 GPU 机器上完整跑一遍 test_correctness()benchmark(),亲手观察长序列下的加速比。


思考题

  1. 关于 fp32 累加:本章核函数里 m_il_iacc 全部用 fp32 dtype,但 Q/K/V 是 fp16。如果把 acc 改成 fp16,会出什么问题?请结合 N_CTX = 8192 这种长序列估算最大累加误差。

  2. 关于 alpha 修正:假设主循环跑到第 5 个 K/V 块,前 4 块的 max 都是 -2.0,第 5 块出现了 max = 10.0。求此时 alpha = exp(m_old - m_new) 的具体数值,并说明这个数值在 acc 修正中的物理含义。

  3. 关于 v1 vs v2:FlashAttention v1 把 outer loop 放在 K/V、inner loop 放在 Q;v2 反过来。试着画出两种调度下,Q、K、V 分块在 SRAM 中的进出次数,并解释为什么 v2 在长序列下快 ~2×。


下一章我们将系统总结 Triton 开发中的常用模式、性能优化清单、调试技巧与社区资源——把前八章学到的散点经验汇成一张可对照的速查表。

基于 MIT 协议发布