Skip to content

12. Hopper / Blackwell 架构特性

把 H100/B200 的硬件红利吃干榨净:TMA、Warp Specialization、Async MMA、TMEM、FP4 —— 让你的 Triton kernel 跑出与 cuBLAS / FlashAttention-3 同级别的数字。

前 11 章我们假设 GPU 是一个"会算矩阵的盒子",调好 BLOCK / num_warps / num_stages 就能把性能榨出来。这套思路在 Ampere(A100/RTX 30/40)上行得通;到了 Hopper(H100、H200、GH200)和 Blackwell(B100/B200/GB200),Tensor Core 的算力增长远快于 SM 调度器和访存路径 —— 想跑到 80% 利用率,必须显式利用 TMA、Warp Specialization、Tensor Memory 这些新硬件。本章带你从原理到代码,吃透 Triton 在这两代架构上提供的所有钥匙。

本章内容概览

  • 12.1 架构演进与性能跃迁
  • 12.2 TMA:硬件加速的张量搬运
  • 12.3 Warp Specialization:生产者/消费者范式
  • 12.4 Multi-stage Pipelining 深度调优
  • 12.5 FlashAttention v3 的 Hopper 三板斧
  • 12.6 Blackwell 前瞻:TMEM、tcgen05 与 Gluon
  • 12.7 同一个 matmul:Ampere vs Hopper 写法对比
  • 12.8 本章小结与思考题

12.1 为什么需要关注硬件架构演进

12.1.1 三代算力对比

下表是单卡 dense Tensor Core 峰值(不含稀疏;BF16/FP8/FP4 分别取代表性数据):

GPU架构FP16/BF16FP8FP4HBM 带宽SMEM/SM
A100 80GBAmpere SM80312 TFLOPS2.04 TB/s192 KB
H100 SXM5Hopper SM90989 TFLOPS1979 TFLOPS3.35 TB/s228 KB
H200 SXM5Hopper SM90989 TFLOPS1979 TFLOPS4.80 TB/s228 KB
B200Blackwell SM1002250 TFLOPS4500 TFLOPS9000 TFLOPS8.00 TB/s228 KB
GB200 (per GPU)Blackwell SM1002500 TFLOPS5000 TFLOPS10000 TFLOPS8.00 TB/s228 KB

数据来源:NVIDIA H100/H200 Tensor Core GPU Datasheet、NVIDIA Blackwell B200 Datasheet (primeline-solutions.com)。

12.1.2 喂饱 Tensor Core 的难度指数级上升

算力增长了 3×(A100→H100)→ 2.3×(H100→B200),但 SMEM 几乎没变,HBM 带宽只增长了 1.6× 和 2.4×。这意味着:

  • 算/访存比(FLOPs per byte)从 ~150 跳到 ~280 再跳到 ~560
  • 每一代都要求 kernel 更激进地复用数据、更深地流水、更早地异步发起搬运
  • 老 Ampere 风格的 tl.load → tl.dot → tl.store 直来直去 + num_stages=3 双缓冲,在 H100 上只能跑到 35-50%(FlashAttention-2 仅 35%)

硬件鸿沟

H100 不再是"更大更快的 A100"。它要求你显式分离搬运 warp 与计算 warp,否则 WGMMA 单元会大量空转等数据。

12.1.3 Triton 抽象的两条路径

Triton 在新架构上提供了两层 API,复杂度递增:

层次API 入口适用场景
高层(主线 Triton DSL)@triton.jit + tl.make_tensor_descriptor + autotune num_consumer_groups80% 的 GEMM / FA / Norm 场景
低层(Gluon experimental)@gluon.jit + 显式 allocate_tensor_memory / tcgen05_mma / mbarrier极端性能、Blackwell TMEM 精细控制

本章主要讲高层路径,12.6 节介绍 Gluon。


12.2 TMA:硬件加速的张量搬运

12.2.1 原理:从地址计算解放线程

Tensor Memory Accelerator(TMA)是 Hopper 引入、Blackwell 增强的专用硬件单元,负责 GMEM ↔ SMEM 的 1D-5D 张量异步搬运。

对比 Ampere 的 cp.async(LDGSTS):

维度Ampere cp.asyncHopper TMA
地址生成每个线程自己算 stride / offset / mask单线程创建 descriptor,硬件负责全部计算
越界处理软件 predication硬件自动 padding(零填充 / 常数)
同步原语cp.async.commit_group + wait_groupmbarrier::arrive::expect_tx + mbarrier::try_wait
Multicast不支持同一份数据可同时投递到 cluster 内多 SM 的 SMEM
寄存器占用高(每线程都要算地址)极低(只一个线程发起)

底层 PTX 指令示例(H100 FP8 GEMM tile load):

text
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
    [%smem_dst], [%tma_desc, {%coord_x, %coord_y}], [%mbar];

来源:PyTorch — Deep Dive on the Hopper TMA Unit for FP8 GEMMsNVIDIA Hopper Architecture In-Depth

12.2.2 Triton API:tl.make_tensor_descriptor

这是 Triton 3.2 起脱离 _experimental 命名空间的正式 API,主线推荐使用。

最小示例 —— 对矩阵做 in-place 绝对值:

python
import torch
import triton
import triton.language as tl
from typing import Optional

@triton.jit
def inplace_abs(in_out_ptr, M, N,
                M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # 创建 TMA descriptor,shape 与 strides 必须为运行时张量
    # block_shape 必须是 constexpr,决定每次 load/store 的 tile 大小
    desc = tl.make_tensor_descriptor(
        in_out_ptr,
        shape=[M, N],                    # 全局张量形状
        strides=[N, 1],                  # 行主序:最末维必须 contiguous
        block_shape=[M_BLOCK, N_BLOCK],  # 单次搬运的 tile 大小
    )
    moffset = tl.program_id(0) * M_BLOCK
    noffset = tl.program_id(1) * N_BLOCK
    # 一行代码完成「按坐标加载 tile」—— TMA 引擎搞定地址计算
    value = desc.load([moffset, noffset])
    desc.store([moffset, noffset], tl.abs(value))

# TMA descriptor 需要一块全局内存放映射表,Triton 要求用户注册 allocator
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
    return torch.empty(size, device="cuda", dtype=torch.int8)

triton.set_allocator(alloc_fn)

M, N = 256, 256
x = torch.randn(M, N, device="cuda")
M_BLOCK, N_BLOCK = 32, 32
grid = (M // M_BLOCK, N // N_BLOCK)
inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)

API 文档:triton.language.make_tensor_descriptor

12.2.3 三条硬约束(容易踩坑)

必须满足,否则 kernel 报错或 fallback 到慢路径

  1. base 指针 16-byte 对齐:从 PyTorch 拿的 torch.empty / torch.randn 默认满足;但 view / narrow 后可能不再对齐。
  2. shape 维度仅支持 2D-5D:1D 张量请 reshape 成 [1, N],6D 以上必须拆分。
  3. strides 约束:最末维必须 contiguous(stride == 1),其余维 stride 须是 16 字节的倍数(即 element_size 整除后是 8 的倍数)。

12.2.4 性能数据

PyTorch 官方在 H100 SXM5 上对比 FP8 GEMM(M=128, N=K=4096):

实现GMEM 吞吐备注
cuBLAS FP81.55 TB/sbaseline
Triton + TMA1.45 TB/s接近 cuBLAS
Triton 手动 cp.async0.95 TB/sTMA 优势明显

数据来源:PyTorch — Deep Dive on the Hopper TMA Unit

何时不需要 TMA

  • A100/V100/T4:硬件不支持,descriptor API 会 fallback 到 cp.async反而徒增 descriptor 创建开销
  • tile 极小(< 4KB):TMA 创建开销摊不平,老 tl.load 更快
  • 不规则索引(gather/scatter):TMA 只支持规则张量切片

12.3 Warp Specialization:生产者/消费者范式

12.3.1 为什么 Hopper 必须 WS

H100 每个 SM 有 8 个 warp scheduler(Ampere 是 4 个),加上新增的 warpgroup-wide 指令(wgmmasetmaxnreg),允许:

  • 将一个 CTA 内的 warps 划分成 partition(warp group),每组承担专门角色
  • 不同 partition 间通过 mbarrier 异步同步,代价远低于 __syncthreads()
  • 寄存器在 warpgroup 间动态再分配:producer 用极少寄存器(只发 TMA 指令),让出给 consumer(需要存 accumulator 与 K-tile)

典型的两 partition 分工:

┌──────────────────────────┐  ┌──────────────────────────┐
│ Producer Warp Group      │  │ Consumer Warp Group(s)   │
│ (1 warp, ~24 regs)       │  │ (4-8 warps, ~232 regs)   │
│                          │  │                          │
│  for tile_k in K:        │  │  for tile_k in K:        │
│    desc.load(K)→smem[k]  │  │    wait(smem[k])         │
│    mbarrier.arrive()     │  │    acc = wgmma(Q,K,acc)  │
│                          │  │    mbarrier.release()    │
└──────────────────────────┘  └──────────────────────────┘
        TMA 引擎在跑                Tensor Core 在跑
              \__________共享 SMEM 多缓冲池__________/

12.3.2 Triton 启用方式:autotune 配置法

这是 3.2+ 稳定路径,已在 PyTorch 2.6 发布。只需要在 triton.Config 里加两个参数:

python
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,      # ★ 启用 WS,并设定 consumer 组数
            num_buffers_warp_spec=3,    # ★ producer→consumer 之间 SMEM 多缓冲数
        ),
        # 也保留非-WS 配置作为 fallback
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
            num_stages=4, num_warps=8,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(...):
    ...

参数解释:

参数含义取值建议
num_consumer_groupsconsumer warp group 数量;0 表示不启用 WS0(关闭)/ 1(单 consumer)/ 2(ping-pong)
num_buffers_warp_spec多缓冲深度,与 num_stages 类似2-4,受 SMEM 容量限制

当前仅支持一个 producer

Triton 3.4 仍只支持单 producer warp group,因此没有 num_producer_groups 参数。

Hopper 上的已知 bug

pointer-based loop + tl.dot + warp_specialize=True 当前在 Hopper(sm90)上会触发 NVGPUWarpSpecialization pass 的 assertion failure(GH #9728)。临时绕过:

  • 改用 tl.make_tensor_descriptor 描述符路径,或
  • 不显式传 warp_specialize=True,让 autotune 通过 num_consumer_groups 隐式启用

12.3.3 编译器 pass 链揭秘

Triton 把这套用户层简单的 autotune 参数翻译成机器码经过 5 个关键 pass:

顺序Pass 名称干了什么
1DataPartition把主循环切成 producer 任务和 consumer 任务,识别哪些 load 由 TMA producer 完成
2WSCodePartition插入 ProducerAcquireOp / ProducerCommitOp / ConsumerWaitOp / ConsumerReleaseOp;标识"channel"(即 producer→consumer 数据流)
3WSLowerToken降到 NVGPU dialect,把 op 翻译成 mbarrier arrive/wait PTX
4Memory Planner按 live range 分配 SMEM/TMEM channel;多缓冲 buffer 复用
5Software Pipelining (SWP)在 WS 之上做 modulo scheduling,跨迭代重排独立 op

详解可参考:PyTorch — Warp Specialization in Triton: Design and RoadmapIan Barber — How does Triton do Warp Spec?

学术参考:Tawa 与 Twill
  • Tawa(CGO'26)实现了自动 WS 上层 IR 重写,在 Hopper FP16 FA 上做到 96% 的 FA3 性能,对 Triton baseline(FA2 风格)有 1.21× 加速。
  • Twill(arXiv 2512.18134)通过最优 modulo scheduling 在 backward attention 上接近/超过 FA3。

这些工作的成果将逐步进入 Triton 主线。

12.3.4 实测增益

PyTorch 官方在 H100 上测量了 WS 对几个 production kernel 的影响:

Kernel关闭 WS启用 WS加速
FlashAttention forward (FP16)baseline+12%1.12×
FP8 row-wise GEMMbaseline+15%1.15×
Persistent matmul (TMA)baseline+10%1.10×

来源:PyTorch — Enabling Warp Specialization


12.4 Multi-stage Pipelining 深度调优

12.4.1 num_stages 的本质

num_stages 控制软件流水深度 —— 编译器会自动把 K 维主循环展开成多缓冲,让"加载下一 tile"与"计算当前 tile"重叠:

text
num_stages = 2 (经典 double buffer)
┌─────┬─────┬─────┬─────┬─────┐
│ load│ load│ load│ load│  -- │
│  -- │ comp│ comp│ comp│ comp│
└─────┴─────┴─────┴─────┴─────┘
  k=0   k=1   k=2   k=3   k=4

num_stages = 4 (深层流水)
┌─────┬─────┬─────┬─────┬─────┐
│load0│load1│load2│load3│load4│
│  -- │  -- │  -- │comp0│comp1│
└─────┴─────┴─────┴─────┴─────┘
  延迟 3 个 stage 后才开始算,更好地隐藏 GMEM 长延迟

12.4.2 不同架构的最佳值

经验调参表(GEMM 类负载)

架构推荐 num_stages原因
A100 (Ampere SM80)3-4cp.async 延迟 ~600 cycle,3-4 stage 够隐藏
H100 (Hopper SM90, 普通 path)4-5HBM3 延迟更高,需要更深
H100 + TMA + WGMMA4-6TMA 异步 + WGMMA 异步可吃深流水
B200 (Blackwell SM100)5-7HBM3e 带宽大但延迟也大,TMEM 提供更多 buffer

12.4.3 SMEM 容量约束

num_stages 不能无脑加大。每多一级,SMEM 占用 +(BLOCK_M + BLOCK_N) * BLOCK_K * dtype_size

例如 BF16 GEMM、BLOCK_M=128, BLOCK_N=256, BLOCK_K=64

  • 单 stage 用 (128 + 256) * 64 * 2 = 48 KB
  • H100 每 SM 228 KB SMEM
  • 理论上限:228 / 48 ≈ 4.75,所以 num_stages=4 是上限
  • 想用 num_stages=5,必须把 BLOCK_K 降到 48

编译时报错

SMEM 超限时 Triton 会报:

triton.runtime.errors.OutOfResources:
out of resource: shared memory, Required: 245760, Hardware limit: 232448

此时要么减 num_stages,要么减 BLOCK_K

12.4.4 Triton 与 JAX Mosaic 的哲学差异

这是个值得知道的对比 —— 同样在 Hopper 上做 pipelining,两个 DSL 走完全不同的路

维度TritonJAX Mosaic GPU (Pallas)
流水控制隐式(编译器 SWP pass 自动做)显式(用户手写 plgpu.emit_pipeline
用户负担num_stages 即可需要分别声明 producer/consumer 函数体
优势入门门槛低;小 kernel 调优快大 kernel 控制力强;MegaKernel 易写
劣势复杂场景编译器决策可能不优上手成本高

来源:Mosaic GPU Pipelining — JAX docs


12.5 FlashAttention v3 的 Hopper 三板斧

FlashAttention-3(Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao;NeurIPS 2024 spotlight)专为 Hopper 设计,在 H100 上把 FP16 利用率从 FA2 的 35% 拉到 75%1.5-2× 加速),FP8 路径下达到 1.2 PFLOPs/s

论文 / 项目:arXiv 2407.08608tridao.me/publications/flash3

12.5.1 三大技术贡献

① Producer-Consumer Warp Specialization(inter-warpgroup overlap)

  • WG0(producer,1 个 warpgroup = 4 warps):只发 TMA load,把 Q/K/V tile 写入 SMEM 多缓冲池
  • WG1, WG2(consumer):每组执行 FA2 风格的内循环(QK^T → softmax → P·V)
  • 两个 consumer warpgroup ping-pong 错相运行:当 WG1 在算 softmax 时,WG2 在跑 WGMMA;反之亦然 —— 让 softmax 的 exp/div 完全被 WGMMA 隐藏

寄存器再分配是关键

text
producer warpgroup: setmaxnreg.dec.sync.aligned.u32 24    # 释放寄存器
consumer warpgroup: setmaxnreg.inc.sync.aligned.u32 232  # 接收寄存器

producer 不需要存 K-tile 与 acc,让出寄存器给 consumer 装更大 BLOCK_N。

② Intra-warpgroup overlap(迭代内重排)

同一个 consumer warpgroup 内部,S = QK^T 的 WGMMA 与上一 block 的 P·V WGMMA 交错,softmax 的标量运算夹在两个 GEMM 之间运行 —— exp/div 完全隐藏。

③ FP8 attention + incoherent processing

  • Block quantization:每个 Q/K/V tile 独立 scale(不是 per-tensor)
  • Incoherent processing:在量化前对 Q/K 乘以随机正交矩阵(Hadamard 变换),打散 outlier,降低 quantization noise —— 思路来自 QuIP#
  • In-kernel transpose:FP8 WGMMA 要求 K 矩阵转置布局,但 K 在 SMEM 中是 row-major,需要在 kernel 内做布局转换 —— 论文重点工程难点

12.5.2 Triton 中如何复现 FA3

主线 Triton 教程仍是 FA2

官方 06-fused-attention.pytl.dot + autotune,没有显式 WS。要想接近 FA3 性能:

  1. 必加num_consumer_groups=2 + num_buffers_warp_spec=3
  2. 必加:用 tl.make_tensor_descriptor 替代手动 tl.load,走 TMA 路径
  3. 可选:FP8 路径需要手写 incoherent processing 与 in-kernel transpose
  4. 可选:让 SWP scheduler 做 modulo scheduling(启用 WS 后自动生效)

Tawa 团队报告:当前 Triton FA-Hopper baseline(autotune + WS)能达到 FA3 的 ~80% 性能;剩余 20% 主要差在 in-kernel transpose 与 epilogue 调度。

学术追踪:Tawa — Automatic WS (arXiv 2510.14719)


12.6 Blackwell 前瞻:TMEM、tcgen05 与 Gluon

12.6.1 硬件新增点

Blackwell(SM100)在 Hopper 基础上新增:

硬件单元容量 / 特性
TMEM(Tensor Memory)每 SM 256 KB,128 行 × 512 列 × 32-bit cell;专门服务 Tensor Core
5th-gen Tensor Core支持 FP4 / FP6 / FP8 / FP16 / BF16 / TF32 / FP64;FP4 dense 9 PFLOPs/GPU
tcgen05.mma单线程指令(不再是 warp-synchronous WGMMA);操作数来自 SMEM 或 TMEM
Decompression Engine800 GB/s 硬件解压(gzip / snappy / Deflate / LZ4)
2nd-gen Transformer Engine集成 micro-scaling(NVFP4 / MXFP8)

来源:NVIDIA Blackwell B200 DatasheetMicrobenchmarking Blackwell (arXiv 2512.02189)

12.6.2 TMEM 编程模型

TMEM 不能用普通 ld.shared 访问,必须用专用 PTX:

指令用途
tcgen05.alloc分配 TMEM 列(必须 power-of-2,[32, 512])
tcgen05.dealloc释放 —— kernel 退出前必须显式调用,否则 leak
tcgen05.ldTMEM → registers
tcgen05.stregisters → TMEM
tcgen05.cpSMEM → TMEM(异步)
tcgen05.mma操作数从 SMEM/TMEM,累加到 TMEM
tcgen05.commit完成 mbarrier signal

Warp 访问限制

  • 每个 warp 只能访问 TMEM 的 32 行(基于 warp ID)
  • 要覆盖全 128 行,必须整 warpgroup 协作
  • TMEM load/store 在 Gluon 中需要 4 或 8 个 warps

12.6.3 高层 vs 低层:两种使用方式

方式 A:主线 Triton DSL(推荐)

继续写 tl.dot + tl.make_tensor_descriptor —— 编译器在 SM100 上自动 lower 到 tcgen05.mma,TMEM 分配也由编译器内部完成。用户层代码与 Hopper 一致

方式 B:Gluon DSL(极端性能)

显式控制 TMEM 分配、layout、tcgen05 指令调度:

python
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout, allocate_tensor_memory,
    tcgen05_mma, tcgen05_commit, fence_async_shared,
    tma, mbarrier,
)

@gluon.jit
def blackwell_matmul_kernel(a_desc, b_desc, d_desc, M, N, K, ...):
    # 显式 TMEM layout:col_stride 影响访存模式
    acc_tmem_layout: gl.constexpr = TensorMemoryLayout(
        tmem_block.value,
        col_stride=32 // d_desc.dtype.primitive_bitwidth,
    )
    # 显式分配 TMEM —— 必须配对 dealloc
    acc_tmem = allocate_tensor_memory(d_desc.dtype, [M, N], acc_tmem_layout)

    # ...TMA load 到 SMEM...
    tcgen05_mma(a_smem, b_smem, acc_tmem)   # 单线程发 MMA
    tcgen05_commit(barrier)                  # 用 mbarrier 通知完成

完整教程:Triton — The 5th Generation TensorCore、源码 tutorials/gluon/06-tcgen05.py

12.6.4 当前稳定性

特性Triton 3.x 主线Gluon
Blackwell tcgen05.mma 透明使用(tl.dot自动 lower
TMEM 手动分配不支持已支持
FP8 GEMM稳定稳定
FP6 GEMM实验实验
FP4 GEMM实验,部分 layout实验
tcgen05.cp(SMEM→TMEM 异步 copy)编译器内部使用已暴露

选择建议

  • 新项目、要在 B200 上跑:先用主线 DSL,让编译器自动用 TMEM
  • 性能差距 >10% 且能定位到 TMEM 调度问题:再考虑下沉到 Gluon
  • 学习目的:先看 Gluon 教程理解硬件模型,再回来理解高层 DSL 做了什么

12.7 同一个 matmul:Ampere vs Hopper 写法对比

让我们用一个 BF16 GEMM 实例对照两种写法 —— 关键差异已用注释标出

12.7.1 Ampere 风格(A100 上最优)

python
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64,
                       "GROUP_M": 8},
                      num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32,
                       "GROUP_M": 8},
                      num_stages=4, num_warps=4),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_ampere(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
    pid = tl.program_id(0)
    # ... 标准的 swizzle 2D pid 计算 ...
    pid_m, pid_n = swizzle_pid(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # ★ Ampere 风格:手动算指针,每次 tl.load 都要重算 mask
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # ★ 手动 mask,处理 K 维边界
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc += tl.dot(a, b)              # mma.sync.aligned.m16n8k16
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mask)

12.7.2 Hopper 风格(H100 上最优)

python
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64,
                       "GROUP_M": 8},
                      num_stages=4, num_warps=4,
                      num_consumer_groups=2,        # ★ 启用 WS
                      num_buffers_warp_spec=3),     # ★ producer-consumer 缓冲
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_hopper(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
    pid = tl.program_id(0)
    pid_m, pid_n = swizzle_pid(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M)

    # ★ Hopper 风格:用 TMA descriptor 替代手动指针计算
    a_desc = tl.make_tensor_descriptor(
        a_ptr, shape=[M, K], strides=[K, 1],
        block_shape=[BLOCK_M, BLOCK_K],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr, shape=[K, N], strides=[N, 1],
        block_shape=[BLOCK_K, BLOCK_N],
    )
    c_desc = tl.make_tensor_descriptor(
        c_ptr, shape=[M, N], strides=[N, 1],
        block_shape=[BLOCK_M, BLOCK_N],
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # ★ 一行完成 TMA load —— 无需手动算 ptr 与 mask(硬件 padding)
        a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
        b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N])
        acc = tl.dot(a, b, acc)           # 自动 lower 到 wgmma.async.m64n256k16

    # ★ TMA store
    c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(tl.bfloat16))

12.7.3 差异总结

维度Ampere 风格Hopper 风格
内存访问手动指针 + maskmake_tensor_descriptor + 坐标
边界处理软件 mask(每次 load)硬件 padding(descriptor 时设置)
异步搬运cp.async(编译器自动)TMA cp.async.bulk.tensor
流水深度num_stages=3-4num_stages=4-6
WSnum_consumer_groups=2
MMA 指令mma.sync.aligned.m16n8k16wgmma.async.m64n256k16
寄存器压力高(每线程算地址)低(descriptor 单线程)
典型利用率70-80%(A100)60-75%(H100,需 WS 才达 75%+)

工程实践

  1. 写一份代码跑两个架构:Triton autotune 会自动挑最优 config —— 在 Ampere config 不带 num_consumer_groups,在 Hopper config 加上
  2. triton.runtime.driver.active.get_current_target() 判断硬件,分支选择 kernel 入口
  3. 持久 matmul(persistent kernel) 是 H100 上的另一关键技巧 —— 把 grid 缩到 SM 数,每个 program 循环处理多个 tile,进一步降低 launch 开销与提升 L2 复用

完整 persistent matmul 实现:Triton 教程 09-persistent-matmul


12.8 本章小结

本章把 Hopper 与 Blackwell 上 Triton 提供的关键新特性串成一条线 —— 从硬件演进的算/访存比挑战,到 TMA 解放线程的搬运范式,到 Warp Specialization 的生产者/消费者分工,再到 FlashAttention v3 的三板斧,最后一瞥 Blackwell 的 TMEM 与 Gluon 低层 API。

关键 takeaways

  1. TMA 不是可选项:H100+ 上跑 GEMM/Attention 不用 make_tensor_descriptor,相当于把 30% 性能扔在桌上
  2. WS 的两参数公式num_consumer_groups=2, num_buffers_warp_spec=3 是 90% Hopper GEMM/FA 的起点配置
  3. num_stages 要跟着架构走:A100 用 3,H100 用 4-5,B200 用 5-7
  4. FA3 三板斧:producer-consumer WS + 迭代内 GEMM/softmax 交错 + FP8 incoherent processing
  5. Blackwell 高低双 API:日常用主线 DSL 让编译器透明 lower 到 tcgen05,极致性能下沉 Gluon

Triton 3.x 功能稳定性速查

特性状态推荐使用
tl.make_tensor_descriptor✅ 稳定(3.2+)
num_consumer_groups autotune✅ 稳定(3.2+)
warp_specialize=True 显式参数⚠️ pointer-based 路径有 bug用描述符路径
tcgen05.mma 透明使用✅ 编译器自动
Gluon TMEM API🧪 实验仅极端性能场景
FP4 / FP6 GEMM🧪 PoC暂不推荐生产

思考题

思考题 1:TMA descriptor 的隐藏开销

tl.make_tensor_descriptor 在 kernel 内每次调用都会创建一个 descriptor。如果你的 kernel 在主循环外创建 descriptor 后又在循环内用同一 descriptor 100 次 load,与"每次循环都重新创建 descriptor 再 load"相比,性能差距来自哪里?请用 NSight Compute 设计一个最小实验验证你的猜想。

思考题 2:WS 不总是更快

题目:用 BLOCK_M=64, BLOCK_N=64, BLOCK_K=32num_warps=2 的小 GEMM kernel,分别测 num_consumer_groups=0num_consumer_groups=2 在 H100 上的性能。如果发现 WS 反而更慢,可能的原因有哪些?(提示:思考 warpgroup 同步开销、寄存器分配、producer 利用率三个角度)

思考题 3:Blackwell TMEM 容量规划

B200 每 SM 有 256 KB TMEM。假设你要写一个 BF16 GEMM kernel,目标 BLOCK_M=128, BLOCK_N=256,且想 double-buffer 累加器以重叠 epilogue 与下一 K-tile 的计算。请计算:

  1. 单个 FP32 累加器 tile 占多少 TMEM 列?
  2. Double buffering 后需要多少列?
  3. 剩余的 TMEM 列能否再容纳 FP16 A/B operand(如果想把 A 放到 TMEM 提速)?
思考题 3 参考答案
  1. 128 × 256 FP32 累加器 = 128 × 256 × 4 bytes。TMEM 每列是 128 rows × 4 bytes = 512 bytes,所以单 tile 需要 256 × 4 / 4 = 256 列 —— 但 TMEM 只有 512 列,已用一半
  2. Double buffer 需要 512 列 —— 正好占满 TMEM
  3. 没有剩余列;想再放 A operand 必须把累加器降到 single buffer,或减小 BLOCK_N 到 128(这样累加器只用 256 列双缓冲 + 还剩 256 列给 A)。这就是为什么 Blackwell GEMM 比 Hopper 更需要精细 tile 规划。

下一章我们将走进真正的工业代码 —— vLLM、SGLang、PyTorch Inductor、Unsloth 这些每天服务百亿请求的系统,看它们如何把 Triton 用到极致。

基于 MIT 协议发布