12. Hopper / Blackwell 架构特性
把 H100/B200 的硬件红利吃干榨净:TMA、Warp Specialization、Async MMA、TMEM、FP4 —— 让你的 Triton kernel 跑出与 cuBLAS / FlashAttention-3 同级别的数字。
前 11 章我们假设 GPU 是一个"会算矩阵的盒子",调好 BLOCK / num_warps / num_stages 就能把性能榨出来。这套思路在 Ampere(A100/RTX 30/40)上行得通;到了 Hopper(H100、H200、GH200)和 Blackwell(B100/B200/GB200),Tensor Core 的算力增长远快于 SM 调度器和访存路径 —— 想跑到 80% 利用率,必须显式利用 TMA、Warp Specialization、Tensor Memory 这些新硬件。本章带你从原理到代码,吃透 Triton 在这两代架构上提供的所有钥匙。
本章内容概览
- 12.1 架构演进与性能跃迁
- 12.2 TMA:硬件加速的张量搬运
- 12.3 Warp Specialization:生产者/消费者范式
- 12.4 Multi-stage Pipelining 深度调优
- 12.5 FlashAttention v3 的 Hopper 三板斧
- 12.6 Blackwell 前瞻:TMEM、tcgen05 与 Gluon
- 12.7 同一个 matmul:Ampere vs Hopper 写法对比
- 12.8 本章小结与思考题
12.1 为什么需要关注硬件架构演进
12.1.1 三代算力对比
下表是单卡 dense Tensor Core 峰值(不含稀疏;BF16/FP8/FP4 分别取代表性数据):
| GPU | 架构 | FP16/BF16 | FP8 | FP4 | HBM 带宽 | SMEM/SM |
|---|---|---|---|---|---|---|
| A100 80GB | Ampere SM80 | 312 TFLOPS | — | — | 2.04 TB/s | 192 KB |
| H100 SXM5 | Hopper SM90 | 989 TFLOPS | 1979 TFLOPS | — | 3.35 TB/s | 228 KB |
| H200 SXM5 | Hopper SM90 | 989 TFLOPS | 1979 TFLOPS | — | 4.80 TB/s | 228 KB |
| B200 | Blackwell SM100 | 2250 TFLOPS | 4500 TFLOPS | 9000 TFLOPS | 8.00 TB/s | 228 KB |
| GB200 (per GPU) | Blackwell SM100 | 2500 TFLOPS | 5000 TFLOPS | 10000 TFLOPS | 8.00 TB/s | 228 KB |
数据来源:NVIDIA H100/H200 Tensor Core GPU Datasheet、NVIDIA Blackwell B200 Datasheet (primeline-solutions.com)。
12.1.2 喂饱 Tensor Core 的难度指数级上升
算力增长了 3×(A100→H100)→ 2.3×(H100→B200),但 SMEM 几乎没变,HBM 带宽只增长了 1.6× 和 2.4×。这意味着:
- 算/访存比(FLOPs per byte)从 ~150 跳到 ~280 再跳到 ~560
- 每一代都要求 kernel 更激进地复用数据、更深地流水、更早地异步发起搬运
- 老 Ampere 风格的
tl.load → tl.dot → tl.store直来直去 +num_stages=3双缓冲,在 H100 上只能跑到 35-50%(FlashAttention-2 仅 35%)
硬件鸿沟
H100 不再是"更大更快的 A100"。它要求你显式分离搬运 warp 与计算 warp,否则 WGMMA 单元会大量空转等数据。
12.1.3 Triton 抽象的两条路径
Triton 在新架构上提供了两层 API,复杂度递增:
| 层次 | API 入口 | 适用场景 |
|---|---|---|
| 高层(主线 Triton DSL) | @triton.jit + tl.make_tensor_descriptor + autotune num_consumer_groups | 80% 的 GEMM / FA / Norm 场景 |
| 低层(Gluon experimental) | @gluon.jit + 显式 allocate_tensor_memory / tcgen05_mma / mbarrier | 极端性能、Blackwell TMEM 精细控制 |
本章主要讲高层路径,12.6 节介绍 Gluon。
12.2 TMA:硬件加速的张量搬运
12.2.1 原理:从地址计算解放线程
Tensor Memory Accelerator(TMA)是 Hopper 引入、Blackwell 增强的专用硬件单元,负责 GMEM ↔ SMEM 的 1D-5D 张量异步搬运。
对比 Ampere 的 cp.async(LDGSTS):
| 维度 | Ampere cp.async | Hopper TMA |
|---|---|---|
| 地址生成 | 每个线程自己算 stride / offset / mask | 单线程创建 descriptor,硬件负责全部计算 |
| 越界处理 | 软件 predication | 硬件自动 padding(零填充 / 常数) |
| 同步原语 | cp.async.commit_group + wait_group | mbarrier::arrive::expect_tx + mbarrier::try_wait |
| Multicast | 不支持 | 同一份数据可同时投递到 cluster 内多 SM 的 SMEM |
| 寄存器占用 | 高(每线程都要算地址) | 极低(只一个线程发起) |
底层 PTX 指令示例(H100 FP8 GEMM tile load):
cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
[%smem_dst], [%tma_desc, {%coord_x, %coord_y}], [%mbar];来源:PyTorch — Deep Dive on the Hopper TMA Unit for FP8 GEMMs、NVIDIA Hopper Architecture In-Depth。
12.2.2 Triton API:tl.make_tensor_descriptor
这是 Triton 3.2 起脱离 _experimental 命名空间的正式 API,主线推荐使用。
最小示例 —— 对矩阵做 in-place 绝对值:
import torch
import triton
import triton.language as tl
from typing import Optional
@triton.jit
def inplace_abs(in_out_ptr, M, N,
M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
# 创建 TMA descriptor,shape 与 strides 必须为运行时张量
# block_shape 必须是 constexpr,决定每次 load/store 的 tile 大小
desc = tl.make_tensor_descriptor(
in_out_ptr,
shape=[M, N], # 全局张量形状
strides=[N, 1], # 行主序:最末维必须 contiguous
block_shape=[M_BLOCK, N_BLOCK], # 单次搬运的 tile 大小
)
moffset = tl.program_id(0) * M_BLOCK
noffset = tl.program_id(1) * N_BLOCK
# 一行代码完成「按坐标加载 tile」—— TMA 引擎搞定地址计算
value = desc.load([moffset, noffset])
desc.store([moffset, noffset], tl.abs(value))
# TMA descriptor 需要一块全局内存放映射表,Triton 要求用户注册 allocator
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
M, N = 256, 256
x = torch.randn(M, N, device="cuda")
M_BLOCK, N_BLOCK = 32, 32
grid = (M // M_BLOCK, N // N_BLOCK)
inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)12.2.3 三条硬约束(容易踩坑)
必须满足,否则 kernel 报错或 fallback 到慢路径
- base 指针 16-byte 对齐:从 PyTorch 拿的
torch.empty / torch.randn默认满足;但 view / narrow 后可能不再对齐。 - shape 维度仅支持 2D-5D:1D 张量请 reshape 成
[1, N],6D 以上必须拆分。 - strides 约束:最末维必须 contiguous(stride == 1),其余维 stride 须是 16 字节的倍数(即 element_size 整除后是 8 的倍数)。
12.2.4 性能数据
PyTorch 官方在 H100 SXM5 上对比 FP8 GEMM(M=128, N=K=4096):
| 实现 | GMEM 吞吐 | 备注 |
|---|---|---|
| cuBLAS FP8 | 1.55 TB/s | baseline |
| Triton + TMA | 1.45 TB/s | 接近 cuBLAS |
Triton 手动 cp.async | 0.95 TB/s | TMA 优势明显 |
何时不需要 TMA
- A100/V100/T4:硬件不支持,descriptor API 会 fallback 到
cp.async,反而徒增 descriptor 创建开销 - tile 极小(< 4KB):TMA 创建开销摊不平,老
tl.load更快 - 不规则索引(gather/scatter):TMA 只支持规则张量切片
12.3 Warp Specialization:生产者/消费者范式
12.3.1 为什么 Hopper 必须 WS
H100 每个 SM 有 8 个 warp scheduler(Ampere 是 4 个),加上新增的 warpgroup-wide 指令(wgmma、setmaxnreg),允许:
- 将一个 CTA 内的 warps 划分成 partition(warp group),每组承担专门角色
- 不同 partition 间通过
mbarrier异步同步,代价远低于__syncthreads() - 寄存器在 warpgroup 间动态再分配:producer 用极少寄存器(只发 TMA 指令),让出给 consumer(需要存 accumulator 与 K-tile)
典型的两 partition 分工:
┌──────────────────────────┐ ┌──────────────────────────┐
│ Producer Warp Group │ │ Consumer Warp Group(s) │
│ (1 warp, ~24 regs) │ │ (4-8 warps, ~232 regs) │
│ │ │ │
│ for tile_k in K: │ │ for tile_k in K: │
│ desc.load(K)→smem[k] │ │ wait(smem[k]) │
│ mbarrier.arrive() │ │ acc = wgmma(Q,K,acc) │
│ │ │ mbarrier.release() │
└──────────────────────────┘ └──────────────────────────┘
TMA 引擎在跑 Tensor Core 在跑
\__________共享 SMEM 多缓冲池__________/12.3.2 Triton 启用方式:autotune 配置法
这是 3.2+ 稳定路径,已在 PyTorch 2.6 发布。只需要在 triton.Config 里加两个参数:
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
num_stages=2,
num_warps=4,
num_consumer_groups=2, # ★ 启用 WS,并设定 consumer 组数
num_buffers_warp_spec=3, # ★ producer→consumer 之间 SMEM 多缓冲数
),
# 也保留非-WS 配置作为 fallback
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
num_stages=4, num_warps=8,
),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(...):
...参数解释:
| 参数 | 含义 | 取值建议 |
|---|---|---|
num_consumer_groups | consumer warp group 数量;0 表示不启用 WS | 0(关闭)/ 1(单 consumer)/ 2(ping-pong) |
num_buffers_warp_spec | 多缓冲深度,与 num_stages 类似 | 2-4,受 SMEM 容量限制 |
当前仅支持一个 producer
Triton 3.4 仍只支持单 producer warp group,因此没有 num_producer_groups 参数。
Hopper 上的已知 bug
pointer-based loop + tl.dot + warp_specialize=True 当前在 Hopper(sm90)上会触发 NVGPUWarpSpecialization pass 的 assertion failure(GH #9728)。临时绕过:
- 改用
tl.make_tensor_descriptor描述符路径,或 - 不显式传
warp_specialize=True,让 autotune 通过num_consumer_groups隐式启用
12.3.3 编译器 pass 链揭秘
Triton 把这套用户层简单的 autotune 参数翻译成机器码经过 5 个关键 pass:
| 顺序 | Pass 名称 | 干了什么 |
|---|---|---|
| 1 | DataPartition | 把主循环切成 producer 任务和 consumer 任务,识别哪些 load 由 TMA producer 完成 |
| 2 | WSCodePartition | 插入 ProducerAcquireOp / ProducerCommitOp / ConsumerWaitOp / ConsumerReleaseOp;标识"channel"(即 producer→consumer 数据流) |
| 3 | WSLowerToken | 降到 NVGPU dialect,把 op 翻译成 mbarrier arrive/wait PTX |
| 4 | Memory Planner | 按 live range 分配 SMEM/TMEM channel;多缓冲 buffer 复用 |
| 5 | Software Pipelining (SWP) | 在 WS 之上做 modulo scheduling,跨迭代重排独立 op |
详解可参考:PyTorch — Warp Specialization in Triton: Design and Roadmap、Ian Barber — How does Triton do Warp Spec?。
学术参考:Tawa 与 Twill
- Tawa(CGO'26)实现了自动 WS 上层 IR 重写,在 Hopper FP16 FA 上做到 96% 的 FA3 性能,对 Triton baseline(FA2 风格)有 1.21× 加速。
- Twill(arXiv 2512.18134)通过最优 modulo scheduling 在 backward attention 上接近/超过 FA3。
这些工作的成果将逐步进入 Triton 主线。
12.3.4 实测增益
PyTorch 官方在 H100 上测量了 WS 对几个 production kernel 的影响:
| Kernel | 关闭 WS | 启用 WS | 加速 |
|---|---|---|---|
| FlashAttention forward (FP16) | baseline | +12% | 1.12× |
| FP8 row-wise GEMM | baseline | +15% | 1.15× |
| Persistent matmul (TMA) | baseline | +10% | 1.10× |
12.4 Multi-stage Pipelining 深度调优
12.4.1 num_stages 的本质
num_stages 控制软件流水深度 —— 编译器会自动把 K 维主循环展开成多缓冲,让"加载下一 tile"与"计算当前 tile"重叠:
num_stages = 2 (经典 double buffer)
┌─────┬─────┬─────┬─────┬─────┐
│ load│ load│ load│ load│ -- │
│ -- │ comp│ comp│ comp│ comp│
└─────┴─────┴─────┴─────┴─────┘
k=0 k=1 k=2 k=3 k=4
num_stages = 4 (深层流水)
┌─────┬─────┬─────┬─────┬─────┐
│load0│load1│load2│load3│load4│
│ -- │ -- │ -- │comp0│comp1│
└─────┴─────┴─────┴─────┴─────┘
延迟 3 个 stage 后才开始算,更好地隐藏 GMEM 长延迟12.4.2 不同架构的最佳值
经验调参表(GEMM 类负载)
| 架构 | 推荐 num_stages | 原因 |
|---|---|---|
| A100 (Ampere SM80) | 3-4 | cp.async 延迟 ~600 cycle,3-4 stage 够隐藏 |
| H100 (Hopper SM90, 普通 path) | 4-5 | HBM3 延迟更高,需要更深 |
| H100 + TMA + WGMMA | 4-6 | TMA 异步 + WGMMA 异步可吃深流水 |
| B200 (Blackwell SM100) | 5-7 | HBM3e 带宽大但延迟也大,TMEM 提供更多 buffer |
12.4.3 SMEM 容量约束
num_stages 不能无脑加大。每多一级,SMEM 占用 +(BLOCK_M + BLOCK_N) * BLOCK_K * dtype_size。
例如 BF16 GEMM、BLOCK_M=128, BLOCK_N=256, BLOCK_K=64:
- 单 stage 用
(128 + 256) * 64 * 2 = 48 KB - H100 每 SM 228 KB SMEM
- 理论上限:
228 / 48 ≈ 4.75,所以num_stages=4是上限 - 想用
num_stages=5,必须把BLOCK_K降到 48
编译时报错
SMEM 超限时 Triton 会报:
triton.runtime.errors.OutOfResources:
out of resource: shared memory, Required: 245760, Hardware limit: 232448此时要么减 num_stages,要么减 BLOCK_K。
12.4.4 Triton 与 JAX Mosaic 的哲学差异
这是个值得知道的对比 —— 同样在 Hopper 上做 pipelining,两个 DSL 走完全不同的路:
| 维度 | Triton | JAX Mosaic GPU (Pallas) |
|---|---|---|
| 流水控制 | 隐式(编译器 SWP pass 自动做) | 显式(用户手写 plgpu.emit_pipeline) |
| 用户负担 | 调 num_stages 即可 | 需要分别声明 producer/consumer 函数体 |
| 优势 | 入门门槛低;小 kernel 调优快 | 大 kernel 控制力强;MegaKernel 易写 |
| 劣势 | 复杂场景编译器决策可能不优 | 上手成本高 |
12.5 FlashAttention v3 的 Hopper 三板斧
FlashAttention-3(Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao;NeurIPS 2024 spotlight)专为 Hopper 设计,在 H100 上把 FP16 利用率从 FA2 的 35% 拉到 75%(1.5-2× 加速),FP8 路径下达到 1.2 PFLOPs/s。
12.5.1 三大技术贡献
① Producer-Consumer Warp Specialization(inter-warpgroup overlap)
- WG0(producer,1 个 warpgroup = 4 warps):只发 TMA load,把 Q/K/V tile 写入 SMEM 多缓冲池
- WG1, WG2(consumer):每组执行 FA2 风格的内循环(QK^T → softmax → P·V)
- 两个 consumer warpgroup ping-pong 错相运行:当 WG1 在算 softmax 时,WG2 在跑 WGMMA;反之亦然 —— 让 softmax 的 exp/div 完全被 WGMMA 隐藏
寄存器再分配是关键
producer warpgroup: setmaxnreg.dec.sync.aligned.u32 24 # 释放寄存器
consumer warpgroup: setmaxnreg.inc.sync.aligned.u32 232 # 接收寄存器producer 不需要存 K-tile 与 acc,让出寄存器给 consumer 装更大 BLOCK_N。
② Intra-warpgroup overlap(迭代内重排)
同一个 consumer warpgroup 内部,把 S = QK^T 的 WGMMA 与上一 block 的 P·V WGMMA 交错,softmax 的标量运算夹在两个 GEMM 之间运行 —— exp/div 完全隐藏。
③ FP8 attention + incoherent processing
- Block quantization:每个 Q/K/V tile 独立 scale(不是 per-tensor)
- Incoherent processing:在量化前对 Q/K 乘以随机正交矩阵(Hadamard 变换),打散 outlier,降低 quantization noise —— 思路来自 QuIP#
- In-kernel transpose:FP8 WGMMA 要求 K 矩阵转置布局,但 K 在 SMEM 中是 row-major,需要在 kernel 内做布局转换 —— 论文重点工程难点
12.5.2 Triton 中如何复现 FA3
主线 Triton 教程仍是 FA2
官方 06-fused-attention.py 用 tl.dot + autotune,没有显式 WS。要想接近 FA3 性能:
- 必加:
num_consumer_groups=2+num_buffers_warp_spec=3 - 必加:用
tl.make_tensor_descriptor替代手动tl.load,走 TMA 路径 - 可选:FP8 路径需要手写 incoherent processing 与 in-kernel transpose
- 可选:让 SWP scheduler 做 modulo scheduling(启用 WS 后自动生效)
Tawa 团队报告:当前 Triton FA-Hopper baseline(autotune + WS)能达到 FA3 的 ~80% 性能;剩余 20% 主要差在 in-kernel transpose 与 epilogue 调度。
12.6 Blackwell 前瞻:TMEM、tcgen05 与 Gluon
12.6.1 硬件新增点
Blackwell(SM100)在 Hopper 基础上新增:
| 硬件单元 | 容量 / 特性 |
|---|---|
| TMEM(Tensor Memory) | 每 SM 256 KB,128 行 × 512 列 × 32-bit cell;专门服务 Tensor Core |
| 5th-gen Tensor Core | 支持 FP4 / FP6 / FP8 / FP16 / BF16 / TF32 / FP64;FP4 dense 9 PFLOPs/GPU |
tcgen05.mma | 单线程指令(不再是 warp-synchronous WGMMA);操作数来自 SMEM 或 TMEM |
| Decompression Engine | 800 GB/s 硬件解压(gzip / snappy / Deflate / LZ4) |
| 2nd-gen Transformer Engine | 集成 micro-scaling(NVFP4 / MXFP8) |
来源:NVIDIA Blackwell B200 Datasheet、Microbenchmarking Blackwell (arXiv 2512.02189)。
12.6.2 TMEM 编程模型
TMEM 不能用普通 ld.shared 访问,必须用专用 PTX:
| 指令 | 用途 |
|---|---|
tcgen05.alloc | 分配 TMEM 列(必须 power-of-2,[32, 512]) |
tcgen05.dealloc | 释放 —— kernel 退出前必须显式调用,否则 leak |
tcgen05.ld | TMEM → registers |
tcgen05.st | registers → TMEM |
tcgen05.cp | SMEM → TMEM(异步) |
tcgen05.mma | 操作数从 SMEM/TMEM,累加到 TMEM |
tcgen05.commit | 完成 mbarrier signal |
Warp 访问限制
- 每个 warp 只能访问 TMEM 的 32 行(基于 warp ID)
- 要覆盖全 128 行,必须整 warpgroup 协作
- TMEM load/store 在 Gluon 中需要 4 或 8 个 warps
12.6.3 高层 vs 低层:两种使用方式
方式 A:主线 Triton DSL(推荐)
继续写 tl.dot + tl.make_tensor_descriptor —— 编译器在 SM100 上自动 lower 到 tcgen05.mma,TMEM 分配也由编译器内部完成。用户层代码与 Hopper 一致。
方式 B:Gluon DSL(极端性能)
显式控制 TMEM 分配、layout、tcgen05 指令调度:
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout, allocate_tensor_memory,
tcgen05_mma, tcgen05_commit, fence_async_shared,
tma, mbarrier,
)
@gluon.jit
def blackwell_matmul_kernel(a_desc, b_desc, d_desc, M, N, K, ...):
# 显式 TMEM layout:col_stride 影响访存模式
acc_tmem_layout: gl.constexpr = TensorMemoryLayout(
tmem_block.value,
col_stride=32 // d_desc.dtype.primitive_bitwidth,
)
# 显式分配 TMEM —— 必须配对 dealloc
acc_tmem = allocate_tensor_memory(d_desc.dtype, [M, N], acc_tmem_layout)
# ...TMA load 到 SMEM...
tcgen05_mma(a_smem, b_smem, acc_tmem) # 单线程发 MMA
tcgen05_commit(barrier) # 用 mbarrier 通知完成完整教程:Triton — The 5th Generation TensorCore、源码
tutorials/gluon/06-tcgen05.py。
12.6.4 当前稳定性
| 特性 | Triton 3.x 主线 | Gluon |
|---|---|---|
Blackwell tcgen05.mma 透明使用(tl.dot) | 自动 lower | — |
| TMEM 手动分配 | 不支持 | 已支持 |
| FP8 GEMM | 稳定 | 稳定 |
| FP6 GEMM | 实验 | 实验 |
| FP4 GEMM | 实验,部分 layout | 实验 |
tcgen05.cp(SMEM→TMEM 异步 copy) | 编译器内部使用 | 已暴露 |
选择建议
- 新项目、要在 B200 上跑:先用主线 DSL,让编译器自动用 TMEM
- 性能差距 >10% 且能定位到 TMEM 调度问题:再考虑下沉到 Gluon
- 学习目的:先看 Gluon 教程理解硬件模型,再回来理解高层 DSL 做了什么
12.7 同一个 matmul:Ampere vs Hopper 写法对比
让我们用一个 BF16 GEMM 实例对照两种写法 —— 关键差异已用注释标出。
12.7.1 Ampere 风格(A100 上最优)
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64,
"GROUP_M": 8},
num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32,
"GROUP_M": 8},
num_stages=4, num_warps=4),
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_ampere(a_ptr, b_ptr, c_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
pid = tl.program_id(0)
# ... 标准的 swizzle 2D pid 计算 ...
pid_m, pid_n = swizzle_pid(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# ★ Ampere 风格:手动算指针,每次 tl.load 都要重算 mask
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
# ★ 手动 mask,处理 K 维边界
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
acc += tl.dot(a, b) # mma.sync.aligned.m16n8k16
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(tl.bfloat16), mask=mask)12.7.2 Hopper 风格(H100 上最优)
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64,
"GROUP_M": 8},
num_stages=4, num_warps=4,
num_consumer_groups=2, # ★ 启用 WS
num_buffers_warp_spec=3), # ★ producer-consumer 缓冲
],
key=["M", "N", "K"],
)
@triton.jit
def matmul_hopper(a_ptr, b_ptr, c_ptr, M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
pid = tl.program_id(0)
pid_m, pid_n = swizzle_pid(pid, M, N, BLOCK_M, BLOCK_N, GROUP_M)
# ★ Hopper 风格:用 TMA descriptor 替代手动指针计算
a_desc = tl.make_tensor_descriptor(
a_ptr, shape=[M, K], strides=[K, 1],
block_shape=[BLOCK_M, BLOCK_K],
)
b_desc = tl.make_tensor_descriptor(
b_ptr, shape=[K, N], strides=[N, 1],
block_shape=[BLOCK_K, BLOCK_N],
)
c_desc = tl.make_tensor_descriptor(
c_ptr, shape=[M, N], strides=[N, 1],
block_shape=[BLOCK_M, BLOCK_N],
)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
# ★ 一行完成 TMA load —— 无需手动算 ptr 与 mask(硬件 padding)
a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N])
acc = tl.dot(a, b, acc) # 自动 lower 到 wgmma.async.m64n256k16
# ★ TMA store
c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(tl.bfloat16))12.7.3 差异总结
| 维度 | Ampere 风格 | Hopper 风格 |
|---|---|---|
| 内存访问 | 手动指针 + mask | make_tensor_descriptor + 坐标 |
| 边界处理 | 软件 mask(每次 load) | 硬件 padding(descriptor 时设置) |
| 异步搬运 | cp.async(编译器自动) | TMA cp.async.bulk.tensor |
| 流水深度 | num_stages=3-4 | num_stages=4-6 |
| WS | 无 | num_consumer_groups=2 |
| MMA 指令 | mma.sync.aligned.m16n8k16 | wgmma.async.m64n256k16 |
| 寄存器压力 | 高(每线程算地址) | 低(descriptor 单线程) |
| 典型利用率 | 70-80%(A100) | 60-75%(H100,需 WS 才达 75%+) |
工程实践
- 写一份代码跑两个架构:Triton autotune 会自动挑最优 config —— 在 Ampere config 不带
num_consumer_groups,在 Hopper config 加上 - 用
triton.runtime.driver.active.get_current_target()判断硬件,分支选择 kernel 入口 - 持久 matmul(persistent kernel) 是 H100 上的另一关键技巧 —— 把 grid 缩到 SM 数,每个 program 循环处理多个 tile,进一步降低 launch 开销与提升 L2 复用
完整 persistent matmul 实现:Triton 教程 09-persistent-matmul。
12.8 本章小结
本章把 Hopper 与 Blackwell 上 Triton 提供的关键新特性串成一条线 —— 从硬件演进的算/访存比挑战,到 TMA 解放线程的搬运范式,到 Warp Specialization 的生产者/消费者分工,再到 FlashAttention v3 的三板斧,最后一瞥 Blackwell 的 TMEM 与 Gluon 低层 API。
关键 takeaways
- TMA 不是可选项:H100+ 上跑 GEMM/Attention 不用
make_tensor_descriptor,相当于把 30% 性能扔在桌上 - WS 的两参数公式:
num_consumer_groups=2, num_buffers_warp_spec=3是 90% Hopper GEMM/FA 的起点配置 num_stages要跟着架构走:A100 用 3,H100 用 4-5,B200 用 5-7- FA3 三板斧:producer-consumer WS + 迭代内 GEMM/softmax 交错 + FP8 incoherent processing
- Blackwell 高低双 API:日常用主线 DSL 让编译器透明 lower 到
tcgen05,极致性能下沉 Gluon
Triton 3.x 功能稳定性速查
| 特性 | 状态 | 推荐使用 |
|---|---|---|
tl.make_tensor_descriptor | ✅ 稳定(3.2+) | 是 |
num_consumer_groups autotune | ✅ 稳定(3.2+) | 是 |
warp_specialize=True 显式参数 | ⚠️ pointer-based 路径有 bug | 用描述符路径 |
tcgen05.mma 透明使用 | ✅ 编译器自动 | 是 |
| Gluon TMEM API | 🧪 实验 | 仅极端性能场景 |
| FP4 / FP6 GEMM | 🧪 PoC | 暂不推荐生产 |
思考题
思考题 1:TMA descriptor 的隐藏开销
tl.make_tensor_descriptor 在 kernel 内每次调用都会创建一个 descriptor。如果你的 kernel 在主循环外创建 descriptor 后又在循环内用同一 descriptor 100 次 load,与"每次循环都重新创建 descriptor 再 load"相比,性能差距来自哪里?请用 NSight Compute 设计一个最小实验验证你的猜想。
思考题 2:WS 不总是更快
题目:用 BLOCK_M=64, BLOCK_N=64, BLOCK_K=32、num_warps=2 的小 GEMM kernel,分别测 num_consumer_groups=0 与 num_consumer_groups=2 在 H100 上的性能。如果发现 WS 反而更慢,可能的原因有哪些?(提示:思考 warpgroup 同步开销、寄存器分配、producer 利用率三个角度)
思考题 3:Blackwell TMEM 容量规划
B200 每 SM 有 256 KB TMEM。假设你要写一个 BF16 GEMM kernel,目标 BLOCK_M=128, BLOCK_N=256,且想 double-buffer 累加器以重叠 epilogue 与下一 K-tile 的计算。请计算:
- 单个 FP32 累加器 tile 占多少 TMEM 列?
- Double buffering 后需要多少列?
- 剩余的 TMEM 列能否再容纳 FP16 A/B operand(如果想把 A 放到 TMEM 提速)?
思考题 3 参考答案
128 × 256FP32 累加器 =128 × 256 × 4bytes。TMEM 每列是 128 rows × 4 bytes = 512 bytes,所以单 tile 需要256 × 4 / 4 = 256列 —— 但 TMEM 只有 512 列,已用一半。- Double buffer 需要 512 列 —— 正好占满 TMEM。
- 没有剩余列;想再放 A operand 必须把累加器降到 single buffer,或减小
BLOCK_N到 128(这样累加器只用 256 列双缓冲 + 还剩 256 列给 A)。这就是为什么 Blackwell GEMM 比 Hopper 更需要精细 tile 规划。
下一章我们将走进真正的工业代码 —— vLLM、SGLang、PyTorch Inductor、Unsloth 这些每天服务百亿请求的系统,看它们如何把 Triton 用到极致。