6. 自动调优
同一个核函数在不同 GPU、不同输入 shape 下,最优配置可能截然不同。手工试参数是无穷的体力活——
@triton.autotune让编译器替你跑这个搜索。
本章会从装饰器的基本用法讲起,逐参数解析 triton.Config 的含义,结合实战代码(examples/03_matmul.py)展示完整调优空间设计,最后给出生产环境的最佳实践与避坑指南。
6.1 为什么需要自动调优
写过 GPU 核函数的人都遇到过这个场景:
- 同一个 matmul,
BLOCK_SIZE_M=128在 A100 上飞快,换到 H100 就慢了。 - 同一个 softmax,
num_warps=4在 batch=4 时最优,batch=64 就变成num_warps=16更好。 - 改了一下
BLOCK_SIZE_K=64 → 32,性能反而提升了 20%。
原因是:性能由架构特性 + 输入 shape + 寄存器/SRAM 预算三方面共同决定,组合爆炸。手工调一个 shape 要 30 分钟,10 个 shape × 3 个 GPU 就是 15 小时——根本调不过来。
Triton 的解法
让用户写一个"候选配置列表",把 BLOCK_SIZE、num_warps、num_stages 等所有可能的值都列出来。Triton 在核函数首次以某个 shape 调用时跑一遍 benchmark,自动挑出最快的那个,并缓存到内存里。后续调用零开销。
6.2 @triton.autotune 装饰器详解
最基本的用法只有两个参数:
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 128}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_SIZE': 256}, num_warps=8, num_stages=2),
triton.Config({'BLOCK_SIZE': 512}, num_warps=8, num_stages=3),
],
key=['N'], # 仅当 N 变化时才重新跑 autotune
)
@triton.jit
def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < N
x = tl.load(x_ptr + offs, mask=mask)
y = tl.load(y_ptr + offs, mask=mask)
tl.store(out_ptr + offs, x + y, mask=mask)6.2.1 核心参数
| 参数 | 含义 |
|---|---|
configs | triton.Config 列表,每项描述一种候选配置 |
key | 一组参数名;当核函数调用时这些参数值变了,autotune 会重新评估所有配置 |
prune_configs_by | 可选的剪枝函数(见 6.5 节) |
reset_to_zero | 调优时会多次执行核函数;用此参数指定哪些输出 tensor 在每次执行前归零 |
restore_value | 类似 reset_to_zero,但保存输入值而非清零(适合 in-place 更新) |
pre_hook / post_hook | 自定义钩子,覆盖默认的 reset / restore 行为 |
warmup | 预热毫秒数(已 deprecated,新版本会自动管理) |
6.2.2 triton.Config 的参数
triton.Config(
kwargs={'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, ...}, # meta-parameters
num_warps=8, # 每程序实例用多少 warp(NVIDIA 上每 warp = 32 线程)
num_stages=3, # 软件流水线深度
num_ctas=1, # cluster 内 CTA 数(Hopper+,高级用法)
maxnreg=None, # 限制每线程寄存器上限,控制 spill
)每个参数的影响:
BLOCK_SIZE_*(分块尺寸):决定 SRAM 占用与 Tensor Core 调度效率。大分块复用好但占资源多。num_warps:单程序实例的并行度。更多 warp → 更多并行;但每个 warp 可用寄存器减少,可能引发 spill。num_stages:软件流水线深度,类似 CUDA 的cp.async双/多缓冲。增大能更好地隐藏 DRAM 延迟,代价是 SRAM 翻倍占用。- NVIDIA Ampere/Hopper 常用 2~5
- AMD ROCm 推荐迥然不同:单 GEMM → 0;FlashAttention 这类双 GEMM 融合 → 1;详见 ROCm 官方文档。
跨硬件核函数必须分平台 config
NVIDIA 和 AMD 的 num_stages 语义不一致,简单的 "config 列表" 在两种硬件上往往一边性能差。生产代码应该写 get_cuda_autotune_config() 和 get_hip_autotune_config() 两套函数,按 triton.runtime.driver.active.get_current_target() 动态选择。
6.3 工作原理与缓存
理解 autotune 的执行流程能帮你避免很多坑:
首次调用 matmul_kernel(M=1024, N=1024, K=1024)
↓
key = (1024, 1024, 1024)
↓
查缓存:未命中
↓
依次跑所有 config(benchmark 取中位耗时)
↓
选最快的 config,缓存到内存
↓
执行核函数
↓
返回
第二次调用 matmul_kernel(M=1024, N=1024, K=1024)
↓
key = (1024, 1024, 1024)
↓
命中缓存 → 直接用最佳 config 执行,零调优开销6.3.1 持久化编译缓存
autotune 的最佳 config 选择是进程内缓存——重启 Python 进程后会重新跑。但 Triton 还有一层编译产物缓存(PTX / SASS)可以跨进程持久化:
# 让 Triton 把编译产物写到稳定目录
export TRITON_CACHE_DIR=/workspace/.triton_cache下次启动同样的核函数 + 同样的 config 时,Triton 不需要重新编译——但 autotune 仍会重新跑一遍 benchmark 来选 config。
6.3.2 调试用环境变量
# 在每次 autotune 完成后打印选中的 config
export TRITON_PRINT_AUTOTUNING=1输出大致长这样:
Triton autotuning for function matmul_kernel
finished after 0.44s;
best config selected:
BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64,
GROUP_SIZE_M: 8, num_warps: 8, num_stages: 3把这条日志重定向到文件,再人工浏览,能极大帮助你理解哪些 shape 偏好哪些 config。
6.4 实战:为 matmul 设计调优空间
完整代码见 examples/03_matmul.py。这里聚焦 autotune 部分:
def get_autotune_configs():
return [
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8},
num_stages=3, num_warps=8),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
]
@triton.autotune(
configs=get_autotune_configs(),
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...):
...6.4.1 设计要点
覆盖典型 shape 的甜点:6 个 config 覆盖了"大分块 + 少 warp"与"小分块 + 多 warp"两种典型策略。
128 × 256 × 64偏好大 batch,64 × 128 × 32偏好瘦长矩阵。GROUP_SIZE_M=8统一:所有 config 都用8,因为这是 A100/H100 上经验上最佳的 L2 swizzling 粒度。如果你不确定,也可以把它放进搜索空间,但会让 config 总数翻几倍。key=['M', 'N', 'K']:仅当问题尺寸变化时才重新调优。这意味着如果你的应用每次都跑相同 shape,autotune 只有第一次有开销。配 fp8 还有 1 套 config(教程的完整版本会额外加 10 个
BLOCK_SIZE_K=128的 config)。原因是 fp8 单元素更小,可以用更大的 K tile。
6.4.2 候选数量的权衡
- 太少(< 5 个):可能覆盖不到最优解,性能上限受限。
- 太多(> 30 个):首次调优时间剧增,社区报告过单 FlashAttention shape 调优 82 分钟的案例。
- 经验值:matmul 6~16 个 config;FlashAttention 因参数多通常 20~40 个;vecadd 这类简单核函数 3~5 个足够。
6.5 高级特性:剪枝大调优空间
当 config 数量超过 20 个时,可以用 prune_configs_by 在 benchmark 之前先过滤掉一批:
def prune_invalid_configs(configs, named_args, **kwargs):
"""根据运行时参数过滤明显不可行的 config。"""
N_CTX = kwargs["N_CTX"]
STAGE = kwargs["STAGE"]
return [
conf for conf in configs
if conf.kwargs.get("BLOCK_M", 0) <= N_CTX
and (conf.kwargs.get("BLOCK_M", 0) >= conf.kwargs.get("BLOCK_N", 0)
or STAGE == 1)
]
@triton.autotune(
configs=large_config_list,
key=["N_CTX", "HEAD_DIM"],
prune_configs_by={'early_config_prune': prune_invalid_configs},
)
@triton.jit
def attn_fwd_kernel(...):
...prune_configs_by 还支持基于性能模型的剪枝:
def perf_model(N_CTX, HEAD_DIM, BLOCK_M, BLOCK_N, num_warps, **kwargs):
"""返回 config 的预测运行时间。"""
return N_CTX * N_CTX * HEAD_DIM / (BLOCK_M * BLOCK_N * num_warps)
@triton.autotune(
configs=large_config_list,
key=["N_CTX", "HEAD_DIM"],
prune_configs_by={
'perf_model': perf_model,
'top_k': 10, # 只 benchmark 预测前 10 名
},
)这种方式适合调优空间庞大的复杂核函数(如 FlashAttention 的 40+ config)。
6.6 调优策略与最佳实践
6.6.1 写出"好调优"的核函数
- 把所有可能影响性能的参数都设为
tl.constexpr:这样它们成为编译期常量,autotune 才能对每个组合编译独立版本。 - 不要在核函数里写死
num_warps或num_stages:让 autotune 自动处理。 - 每个 config 的
kwargs必须能让核函数跑通:别忘了添加static_assert检查不合法的组合。
6.6.2 部署前的预热
生产服务启动时,对所有典型 shape 都"预跑"一次,把 autotune 开销挪到部署阶段:
def warmup_all_shapes():
"""在服务启动时调用,把 autotune 移到部署阶段。"""
typical_shapes = [
(1024, 1024, 1024),
(2048, 2048, 2048),
(4096, 4096, 4096),
]
for M, N, K in typical_shapes:
a = torch.randn(M, K, device='cuda', dtype=torch.float16)
b = torch.randn(K, N, device='cuda', dtype=torch.float16)
_ = matmul(a, b) # 触发 autotune
torch.cuda.synchronize()
print("Warmup done.")6.6.3 处理写入型核函数
如果核函数会写一个 tensor(典型如 in-place 更新、累加),autotune 多次执行会让结果累加 N 次,必须配 reset_to_zero:
@triton.autotune(
configs=[...],
key=['N'],
reset_to_zero=['output_ptr'], # 每次执行前把 output 清零
)
@triton.jit
def my_kernel(input_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < N
x = tl.load(input_ptr + offs, mask=mask)
# 原子累加,多次执行会重复加
tl.atomic_add(output_ptr, tl.sum(x))类似地,对于需要保留输入原值的核函数,用 restore_value 在每次执行后还原。
6.7 注意事项与坑
坑 1:调优时间失控
autotune 默认会 benchmark 所有 config,并且每个跑多次取中位。Config 数量 × 评估次数 = 总耗时,容易爆炸。社区有报告过 FlashAttention 单 shape 调优 82 分钟的案例(GitHub issue #9401)。
对策:
- 控制 config 数量在 10~20 个
- 用
prune_configs_by过滤无效组合 - 优先调最常用的几个 shape,其它 shape 用相近 shape 的最佳 config
坑 2:benchmark 测不准
罗切斯特大学论文《Characterizing Autotuning Costs in OpenAI's Triton》指出,Triton 的 autotuner 在估算"warmed-up 核函数运行时间"时用未预热样本估算,导致快速核函数的 benchmark 时间被严重低估——本来想测 100ms,实际可能只跑了不到 20ms。
对策:对性能敏感的算子,最终选定的 config 应该用 triton.testing.do_bench 在真实数据上独立 benchmark 复核一次,而不要完全信赖 autotune 的内部测量。
坑 3:写入型核函数结果错误
忘了配 reset_to_zero 时,autotune 多次累加导致输出值是真实值的 N 倍——但核函数本身没报错,只是数值错。
对策:所有写输出的核函数上 autotune 之前,先用单 config 验证正确性,再加 autotune;并且显式声明所有需要 reset 的指针。
坑 4:跨硬件配置混用
NVIDIA 和 AMD 的最优 num_stages 不一致,简单 reuse 一套 config 会让一种硬件性能差 30%+。
对策:用 is_cuda() / is_hip() 判断后返回不同的 config 列表。
6.8 编译开销量化分析
到此为止我们都把 autotune 当成"免费午餐"。它不是。每个 config 都要付出一笔可观的隐藏成本,搞清楚账本才能在生产环境理性使用。
6.8.1 一次 autotune 的开销分解
罗切斯特大学 2024 年的研究论文 Characterizing Autotuning Costs in OpenAI's Triton 把单次 autotune 的耗时拆成三部分:
| 阶段 | 占比 | 说明 |
|---|---|---|
| JIT 编译(每 config 一次) | 占总耗时 ~80% | 走完 triton.jit → TTIR → TTGIR → LLVM IR → PTX → SASS 全流程 |
| Kernel 执行 + benchmark | 占总耗时 ~10% | 多次 launch 取中位,含一次 warmup |
| L2 cache 清空 + 框架开销 | 占总耗时 ~50%(非 kernel/编译部分的大头) | 论文识别出的最大"隐式"成本 |
注意:因为编译和 L2 清空可以部分并行,三者加起来超过 100%。论文的 parallel-compilation 改造把编译阶段并行化后,整体 autotune 时间最高可降 7.9×。
社区报告的极端数据:
- FlashAttention 单 shape × 40 configs:曾耗时 82 分钟 / ~4631 秒(GitHub issue #9401)。
- matmul 6 configs:A100 上典型 0.4~1.5 秒,前 5 个 shape 共耗 5~10 秒。
- vLLM 在线服务首请求:未做预热时第一个 prompt 的 attention kernel 可触发数百毫秒到数秒的 autotune,p99 延迟尖刺明显。
6.8.2 估算公式
给定 N 个 config,单 config 编译时间 T_compile(典型 0.5~3 秒),单次 benchmark 时间 T_bench(典型 25~125 ms):
T_autotune ≈ N × (T_compile + T_bench)
≈ N × T_compile (编译时间通常远大于 benchmark)经验值:
| Config 数 | 简单 kernel(elementwise) | matmul | FlashAttention |
|---|---|---|---|
| 6 | ~1 s | ~3 s | ~10 s |
| 16 | ~3 s | ~10 s | ~40 s |
| 40 | ~8 s | ~25 s | ~120 s |
| 100+ | ~20 s | ~70 s | 可达 30~80 分钟 |
别把 autotune 留给首请求
在线推理服务最忌讳"用户的第一个请求触发 autotune"——p99 延迟会爆掉。Spheron 在 2026 部署指南里测得:未预热的 Docker 容器冷启动时,30~120 秒会被 Triton 编译吃掉。
6.8.3 生产部署的 Ahead-of-Time (AOT) 策略
针对生产环境的三种主流方案:
方案 1:服务启动时主动 warmup(最简单)
def warmup_all_shapes():
"""在容器启动时调用,把 autotune 开销挪到部署阶段。"""
typical_shapes = [
(1, 4096), # batch=1 推理
(8, 4096), # batch=8 推理
(32, 2048), # batch=32 训练
(1, 8192), # 长 context
]
for B, N in typical_shapes:
x = torch.randn(B, N, device='cuda', dtype=torch.float16)
_ = my_kernel(x)
torch.cuda.synchronize()
# Dockerfile / Kubernetes 启动钩子里调用
if __name__ == '__main__':
warmup_all_shapes()
serve_inference()方案 2:持久化 PTX/SASS 缓存(跨重启)
# Dockerfile
ENV TRITON_CACHE_DIR=/triton_cache
# k8s deployment.yaml 把 /triton_cache 挂载为 PVC下次启动后,编译阶段 命中 PTX 缓存(省 80% autotune 时间),但 autotune 的 best config 选择 仍是进程级缓存,依然要跑一遍 benchmark。
方案 3:IBM Dejavu / 离线 config 注入(最彻底)
IBM 在 Ray Summit 2024 开源的 triton-dejavu 把 autotune 选中的 config 序列化到 JSON,下次启动直接读取,完全跳过 benchmark:
# 第一次跑:发现并保存
from triton_dejavu import autotune # 替代 triton.autotune
@autotune(configs=[...], key=['M', 'N', 'K'],
dejavu_dir='/configs/matmul.json')
@triton.jit
def matmul_kernel(...): ...
# 部署时:JSON 已存在,直接命中
# {"(4096, 4096, 4096)": "BLOCK_M:128, BLOCK_N:256, num_warps:8, num_stages:3"}vLLM 团队报告:用 dejavu 把 首请求 autotune 开销从秒级降到零,同时保留 cross-GPU portability。
6.8.4 triton.runtime.autotuner 的缓存语义
理解缓存层级能帮你定位"为什么明明缓存了还在跑":
┌─────────────────────────────────────────────────┐
│ 1. Autotuner.cache (Python dict, 进程内) │ ← key tuple → best Config
│ 生命周期:Python 进程;重启即失效 │
├─────────────────────────────────────────────────┤
│ 2. JITFunction.cache (Python dict, 进程内) │ ← (config, signature) → CompiledKernel
│ 生命周期:Python 进程;编译产物的内存索引 │
├─────────────────────────────────────────────────┤
│ 3. TRITON_CACHE_DIR (磁盘) │ ← hash → .ttir/.ttgir/.ptx/.cubin
│ 生命周期:手动清理之前都在;可跨进程/容器 │
└─────────────────────────────────────────────────┘实测:缓存命中行为
import os, time, torch, triton
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
# 第 1 次:完整 autotune(编译 + benchmark)
t0 = time.time(); _ = matmul(a, b); print(f'1st: {time.time()-t0:.2f}s')
# Triton autotuning ... finished after 3.21s
# 第 2 次:进程内缓存命中
t0 = time.time(); _ = matmul(a, b); print(f'2nd: {time.time()-t0:.4f}s')
# 2nd: 0.0001s
# 重启 Python 进程后(TRITON_CACHE_DIR 已设置)
# PTX 缓存命中 → 跳过编译;但 autotuner 缓存丢失 → 重新 benchmark
# Triton autotuning ... finished after 0.45s (省了 80% 编译时间,但仍要 benchmark)6.9 搜索空间设计的数学建模
随手列 10 个 config 凭运气工作,列得不好性能直接打七折。系统化设计搜索空间需要从硬件 spec 反推合理的参数范围。
6.9.1 BLOCK_SIZE 与 occupancy 的关系
GPU occupancy(SM 上同时驻留的 warp 数 / 最大 warp 数)由三个 SM 资源同时决定,取最严的那个:
$$ \text{Occupancy} = \min!\left( \underbrace{\frac{\text{Regs}{\text{SM}}}{R \cdot 32 \cdot W}}{\text{寄存器限制}},; \underbrace{\frac{\text{SMEM}{\text{SM}}}{S}}{\text{共享内存限制}},; \underbrace{\text{MaxBlocks}{\text{SM}}}{\text{硬件常量}} \right) $$
其中 R = 每线程寄存器数,W = num_warps,S = 单 block 共享内存字节,Regs_SM、SMEM_SM 来自硬件 spec(A100:65536 regs、164 KB SMEM;H100:65536 regs、228 KB SMEM)。
反向推导 BLOCK_SIZE 上限(以 fp16 matmul 为例):
单 block 共享内存 S = 2 × (BM × BK + BK × BN) × 2 bytes # A、B 两个 tile,fp16
+ num_stages × 上述容量 # 多缓冲
要让 occupancy ≥ 4 blocks/SM(经验值),需要:
S ≤ SMEM_SM / 4
A100 (164 KB / 4 = 41 KB) 下取 num_stages=3:
3 × 4 × (BM·BK + BK·BN) ≤ 41 × 1024
BM·BK + BK·BN ≤ 3500
当 BK=32, BM=BN=128 时:32·128 + 32·128 = 8192 → 超
当 BK=32, BM=BN=64 时:32·64 + 32·64 = 4096 → 超但接近
当 BK=32, BM=64, BN=128 时:32·64+32·128 = 6144 → 仍超
⇒ A100 上 num_stages=3 + 128×128×32 已经把 occupancy 压到 ≤ 2 blocks/SM。这就是为什么官方 matmul tutorial 在 A100 上 BLOCK_M=BLOCK_N=128, BLOCK_K=32 不是越大越好——再大就 SMEM 溢出,编译器要么降级 stages 要么 spill 寄存器。
6.9.2 num_warps 与 tile 形状的约束
tl.dot 内部把 [BM, BN] 输出 tile 分给 num_warps 个 warp,每 warp 负责 [BM/warps_m, BN/warps_n] 的子块,由编译器自动选 warps_m × warps_n 的拆分。约束:
BM × BN必须被num_warps × MMA_tile_size整除。在 NVIDIA 上 Tensor Core 的最小 mma 形状是16×16(fp16)或16×8(fp8),所以: $$ BM, BN \in {16, 32, 64, 128, 256, \ldots} \quad\text{且}\quad \frac{BM \cdot BN}{256} \geq \text{num_warps} $$num_warps 上限:单 program 最多 32 warps(NVIDIA),即 1024 线程。
BM=BN=128→tile=16384→ num_warps ≤ 64,但实际 ≤ 16BM=BN=64→tile=4096→ num_warps ≤ 16,常用 4~8BM=BN=32→tile=1024→ num_warps ≤ 4
每 warp 的工作量经验值:单 warp 处理 ≥
32×32元素时 Tensor Core 利用率最高。
6.9.3 GROUP_SIZE_M 与 L2 命中率模型
L2 swizzling 通过让 grid 上相邻的 program 复用 A、B 矩阵的 tile 来提升 L2 命中率。建模:
对 [M, N, K] 的 matmul,分成 num_m × num_n 个输出 block。
朴素行主序:算前 row 个 block 需要 load 的 tile 数:
A: row 个 (每个 block 独立加载一行)
B: row × num_n 个 (每个 block 加载一列)
total = row × (1 + num_n)
Grouped (GROUP_SIZE_M = g):每组 g × num_n 个 block 内复用:
A: g × num_n / g = num_n 个 (一组内 A 行复用 num_n 次)
B: g 个 (一组内 B 列复用 g 次)
total = num_n + g
per-block ratio = (num_n + g) / (g × num_n)最优 g 满足 dL/dg = 0:
$$ g^* = \sqrt{N / BN} $$
即 g 大约取 √(N/BN)。
| N(输出列数) | BN | 推荐 GROUP_SIZE_M |
|---|---|---|
| 1024 | 128 | 3 |
| 4096 | 128 | 6 |
| 8192 | 128 | 8 |
| 16384 | 256 | 8 |
官方推荐统一用 GROUP_SIZE_M=8 是因为它在 N=4K~16K 的常见 LLM shape 下都接近最优。
6.9.4 从 hardware spec 反推最优参数的方法论
完整流程:
1. 查硬件 spec
├─ SMEM/SM (A100=164KB, H100=228KB, B200=256KB)
├─ Regs/SM (统一 65536)
├─ Max threads/SM (A100/H100=2048)
└─ HBM BW (A100=1.55TB/s, H100=3TB/s, B200=8TB/s)
2. 给定目标 occupancy(典型 4 blocks/SM 或 50%)
↓
3. 计算每 block 可用 SRAM 上限
↓
4. 反推 BM × BN × BK × num_stages × sizeof(dtype) ≤ 上限
↓
5. 在满足上限的组合里枚举 BM ∈ {64,128,256}, BN ∈ {64,128,256}, BK ∈ {32,64}
↓
6. 对每个 (BM, BN),按 4.9.2 计算合法 num_warps 取值
↓
7. 用 GROUP_SIZE_M = √(N/BN) 收敛唯一选择照此方法论生成的 config 列表通常只有 6~10 个,且每个都贴近硬件最优区。
6.10 prune_configs_by 高级实战
当 config 数量上 30 之后,剪枝的收益远大于"多跑几个 config"。
6.10.1 基于 hardware 限制的剪枝
def prune_by_smem(configs, named_args, **kwargs):
"""剪掉超过 A100 SMEM (164 KB) 的 config。"""
SMEM_LIMIT = 100 * 1024 # 留 60K 给 stack 和其他开销
dtype_bytes = 2 # fp16
pruned = []
for c in configs:
BM = c.kwargs['BLOCK_SIZE_M']
BN = c.kwargs['BLOCK_SIZE_N']
BK = c.kwargs['BLOCK_SIZE_K']
smem = c.num_stages * 2 * (BM * BK + BK * BN) * dtype_bytes
if smem <= SMEM_LIMIT:
pruned.append(c)
return pruned6.10.2 基于问题规模的剪枝
def prune_by_shape(configs, named_args, **kwargs):
"""小矩阵不用大 BLOCK;瘦长矩阵优先瘦长 BLOCK。"""
M, N, K = kwargs['M'], kwargs['N'], kwargs['K']
pruned = []
for c in configs:
BM = c.kwargs['BLOCK_SIZE_M']
BN = c.kwargs['BLOCK_SIZE_N']
# 规则 1:BLOCK 不能比矩阵大 2 倍以上(浪费)
if BM > 2 * M or BN > 2 * N:
continue
# 规则 2:M < 128 时只用 BM ≤ 64
if M < 128 and BM > 64:
continue
# 规则 3:方阵优先方块 tile(BM ≈ BN)
if abs(M - N) < 0.1 * max(M, N) and abs(BM - BN) > 64:
continue
pruned.append(c)
return pruned6.10.3 自定义 perf_model
性能模型不需要精确——只需排序大致正确即可:
def matmul_perf_model(M, N, K,
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
num_warps, num_stages, **kwargs):
"""返回预测的运行时间(任意单位,越小越快)。"""
# 1. 计算 wave 数:grid 大小 / SM 数
num_blocks = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)
NUM_SM = 108 # A100
waves = num_blocks / NUM_SM
# 2. 单 block 的 FLOPs(fp16 Tensor Core 折算)
flops_per_block = 2 * BLOCK_SIZE_M * BLOCK_SIZE_N * K
# H100 Tensor Core fp16: ~989 TFLOPS / SM 数
tc_throughput = 989e12 / NUM_SM
compute_time_per_block = flops_per_block / tc_throughput
# 3. 单 block 的 HBM 流量
bytes_per_block = 2 * (BLOCK_SIZE_M * K + K * BLOCK_SIZE_N + BLOCK_SIZE_M * BLOCK_SIZE_N)
mem_time_per_block = bytes_per_block / (3e12 / NUM_SM) # H100 HBM
# 4. 取 max(compute-bound 或 memory-bound)
block_time = max(compute_time_per_block, mem_time_per_block)
# 5. 惩罚低 occupancy(粗略:num_warps 太多/太少都有问题)
occupancy_penalty = 1.0
if num_warps > 8: occupancy_penalty = 1.2
if num_warps < 4: occupancy_penalty = 1.3
return waves * block_time * occupancy_penalty
@triton.autotune(
configs=large_config_list, # 100+ configs
key=['M', 'N', 'K'],
prune_configs_by={
'perf_model': matmul_perf_model,
'top_k': 10, # 只 benchmark 前 10 名
},
)
@triton.jit
def matmul_kernel(...): ...6.10.4 实战:100+ → 10 configs 的剪枝
一个 LayerNorm kernel 的真实案例:
# 起始搜索空间(直接枚举)
ALL_CONFIGS = []
for BM in [64, 128, 256, 512]:
for BN in [64, 128, 256, 512, 1024]:
for warps in [2, 4, 8, 16]:
for stages in [1, 2, 3, 4]:
ALL_CONFIGS.append(
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN},
num_warps=warps, num_stages=stages))
# len = 4 × 5 × 4 × 4 = 320
def early_prune(configs, named_args, **kwargs):
N = kwargs['N'] # hidden_dim
pruned = []
for c in configs:
BM, BN = c.kwargs['BLOCK_M'], c.kwargs['BLOCK_N']
# 1. BN 必须 >= N(一行装进 SRAM)
if BN < N: continue
# 2. SMEM 限制
if BM * BN * 2 > 96 * 1024: continue
# 3. tile 总 size ≥ 4 × num_warps × 32 (Tensor Core 利用率)
if BM * BN < 128 * c.num_warps: continue
# 4. AMD 平台 num_stages ≤ 1
if not is_cuda() and c.num_stages > 1: continue
# 5. 去重:BM × num_warps 相同的只留 num_stages 最大的
pruned.append(c)
return pruned
@triton.autotune(
configs=ALL_CONFIGS,
key=['M', 'N'],
prune_configs_by={'early_config_prune': early_prune,
'perf_model': layernorm_perf_model,
'top_k': 8},
)
@triton.jit
def layernorm_kernel(...): ...实战效果:
| 阶段 | Config 数 | A100 autotune 耗时 |
|---|---|---|
| 朴素枚举 | 320 | ~12 分钟 |
| early_config_prune | ~25 | ~50 秒 |
| + perf_model top_k | 8 | ~15 秒 |
性能上:剪枝后选中的 best config 与朴素枚举选中的 吞吐差 < 2%——说明 perf_model 把高优先级 config 都保住了。
perf_model 不要追求精确
试图把 perf_model 写得很准(包含寄存器、warp 调度细节等)通常得不偿失——耗费一周建模换来的精度提升,远不如多列几个 config 让 autotune 实测。模型只需排序大致正确即可。
本章小结
- autotune 是必备工具:手工调参在多 shape × 多 GPU 场景下完全不可行;让编译器替你跑搜索。
@triton.autotune(configs=[...], key=[...])是最基本的用法:configs是候选列表,key决定何时重新调优。triton.Config的关键字段:BLOCK_SIZE_*、num_warps、num_stages,每个都直接影响 SRAM / 寄存器预算与并行度。- AMD vs NVIDIA:
num_stages语义不同,跨硬件必须分平台配置。 - 大调优空间用
prune_configs_by:基于硬件限制 + 问题规模 + perf_model 三层剪枝可以把 320 个 config 剪到 8 个,autotune 时间从 12 分钟降到 15 秒。 - 生产部署的 AOT 策略:warmup、
TRITON_CACHE_DIR持久化、IBM Dejavu 离线 config 三选一,避免首请求 autotune 抖动(实测可达 30~120 秒冷启动)。 - 搜索空间用数学建模设计:从硬件 spec 反推 occupancy、SRAM、num_warps 约束,能让 6~10 个 config 覆盖最优区。
- 常见坑:调优时间爆炸(最高 82 分钟)、benchmark 测不准(默认 warmup=25ms 偏低 30%)、忘了
reset_to_zero、跨硬件 config 混用。
掌握了"挑参数"之后,下一章我们要解决"省搬运"——算子融合把多个连续算子合并成一个核函数,让中间张量只在共享内存里流转,省掉 HBM 往返。
思考题
你为一个
M=K=4096, N动态变化的 GEMM 写了 autotune,key=['M', 'N', 'K']。在生产环境观察到首请求都很慢。请分析原因,并给出至少两种缓解方案。设计一个
prune_configs_by函数,对 matmul 的 config 列表剪枝:要求过滤掉BLOCK_SIZE_M × BLOCK_SIZE_N总元素数 > 32768 的配置(防止 SRAM 溢出)。给出函数签名与实现。你的核函数输出一个
[M, N]矩阵,每次调用会原地累加(即C += A @ B)。如果直接套 autotune,会出什么问题?如何修复?写出修改后的装饰器配置。