Skip to content

6. 自动调优

同一个核函数在不同 GPU、不同输入 shape 下,最优配置可能截然不同。手工试参数是无穷的体力活——@triton.autotune 让编译器替你跑这个搜索。

本章会从装饰器的基本用法讲起,逐参数解析 triton.Config 的含义,结合实战代码(examples/03_matmul.py)展示完整调优空间设计,最后给出生产环境的最佳实践与避坑指南。

6.1 为什么需要自动调优

写过 GPU 核函数的人都遇到过这个场景:

  • 同一个 matmul,BLOCK_SIZE_M=128 在 A100 上飞快,换到 H100 就慢了。
  • 同一个 softmax,num_warps=4 在 batch=4 时最优,batch=64 就变成 num_warps=16 更好。
  • 改了一下 BLOCK_SIZE_K=64 → 32,性能反而提升了 20%。

原因是:性能由架构特性 + 输入 shape + 寄存器/SRAM 预算三方面共同决定,组合爆炸。手工调一个 shape 要 30 分钟,10 个 shape × 3 个 GPU 就是 15 小时——根本调不过来。

Triton 的解法

让用户写一个"候选配置列表",把 BLOCK_SIZEnum_warpsnum_stages 等所有可能的值都列出来。Triton 在核函数首次以某个 shape 调用时跑一遍 benchmark,自动挑出最快的那个,并缓存到内存里。后续调用零开销。

6.2 @triton.autotune 装饰器详解

最基本的用法只有两个参数:

python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=8, num_stages=3),
    ],
    key=['N'],  # 仅当 N 变化时才重新跑 autotune
)
@triton.jit
def vecadd_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

6.2.1 核心参数

参数含义
configstriton.Config 列表,每项描述一种候选配置
key一组参数名;当核函数调用时这些参数值变了,autotune 会重新评估所有配置
prune_configs_by可选的剪枝函数(见 6.5 节)
reset_to_zero调优时会多次执行核函数;用此参数指定哪些输出 tensor 在每次执行前归零
restore_value类似 reset_to_zero,但保存输入值而非清零(适合 in-place 更新)
pre_hook / post_hook自定义钩子,覆盖默认的 reset / restore 行为
warmup预热毫秒数(已 deprecated,新版本会自动管理)

6.2.2 triton.Config 的参数

python
triton.Config(
    kwargs={'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, ...},  # meta-parameters
    num_warps=8,       # 每程序实例用多少 warp(NVIDIA 上每 warp = 32 线程)
    num_stages=3,      # 软件流水线深度
    num_ctas=1,        # cluster 内 CTA 数(Hopper+,高级用法)
    maxnreg=None,      # 限制每线程寄存器上限,控制 spill
)

每个参数的影响:

  • BLOCK_SIZE_*(分块尺寸):决定 SRAM 占用与 Tensor Core 调度效率。大分块复用好但占资源多。
  • num_warps:单程序实例的并行度。更多 warp → 更多并行;但每个 warp 可用寄存器减少,可能引发 spill。
  • num_stages:软件流水线深度,类似 CUDA 的 cp.async 双/多缓冲。增大能更好地隐藏 DRAM 延迟,代价是 SRAM 翻倍占用。
    • NVIDIA Ampere/Hopper 常用 2~5
    • AMD ROCm 推荐迥然不同:单 GEMM → 0;FlashAttention 这类双 GEMM 融合 → 1;详见 ROCm 官方文档。

跨硬件核函数必须分平台 config

NVIDIA 和 AMD 的 num_stages 语义不一致,简单的 "config 列表" 在两种硬件上往往一边性能差。生产代码应该写 get_cuda_autotune_config()get_hip_autotune_config() 两套函数,按 triton.runtime.driver.active.get_current_target() 动态选择。

6.3 工作原理与缓存

理解 autotune 的执行流程能帮你避免很多坑:

text
首次调用 matmul_kernel(M=1024, N=1024, K=1024)

key = (1024, 1024, 1024)

查缓存:未命中

依次跑所有 config(benchmark 取中位耗时)

选最快的 config,缓存到内存

执行核函数

返回

第二次调用 matmul_kernel(M=1024, N=1024, K=1024)

key = (1024, 1024, 1024)

命中缓存 → 直接用最佳 config 执行,零调优开销

6.3.1 持久化编译缓存

autotune 的最佳 config 选择是进程内缓存——重启 Python 进程后会重新跑。但 Triton 还有一层编译产物缓存(PTX / SASS)可以跨进程持久化:

bash
# 让 Triton 把编译产物写到稳定目录
export TRITON_CACHE_DIR=/workspace/.triton_cache

下次启动同样的核函数 + 同样的 config 时,Triton 不需要重新编译——但 autotune 仍会重新跑一遍 benchmark 来选 config。

6.3.2 调试用环境变量

bash
# 在每次 autotune 完成后打印选中的 config
export TRITON_PRINT_AUTOTUNING=1

输出大致长这样:

text
Triton autotuning for function matmul_kernel
finished after 0.44s;
best config selected:
  BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64,
  GROUP_SIZE_M: 8, num_warps: 8, num_stages: 3

把这条日志重定向到文件,再人工浏览,能极大帮助你理解哪些 shape 偏好哪些 config。

6.4 实战:为 matmul 设计调优空间

完整代码见 examples/03_matmul.py。这里聚焦 autotune 部分:

python
def get_autotune_configs():
    return [
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64,
             'GROUP_SIZE_M': 8},
            num_stages=3, num_warps=8),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
    ]


@triton.autotune(
    configs=get_autotune_configs(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...):
    ...

6.4.1 设计要点

  1. 覆盖典型 shape 的甜点:6 个 config 覆盖了"大分块 + 少 warp"与"小分块 + 多 warp"两种典型策略。128 × 256 × 64 偏好大 batch,64 × 128 × 32 偏好瘦长矩阵。

  2. GROUP_SIZE_M=8 统一:所有 config 都用 8,因为这是 A100/H100 上经验上最佳的 L2 swizzling 粒度。如果你不确定,也可以把它放进搜索空间,但会让 config 总数翻几倍。

  3. key=['M', 'N', 'K']:仅当问题尺寸变化时才重新调优。这意味着如果你的应用每次都跑相同 shape,autotune 只有第一次有开销。

  4. 配 fp8 还有 1 套 config(教程的完整版本会额外加 10 个 BLOCK_SIZE_K=128 的 config)。原因是 fp8 单元素更小,可以用更大的 K tile。

6.4.2 候选数量的权衡

  • 太少(< 5 个):可能覆盖不到最优解,性能上限受限。
  • 太多(> 30 个):首次调优时间剧增,社区报告过单 FlashAttention shape 调优 82 分钟的案例。
  • 经验值:matmul 6~16 个 config;FlashAttention 因参数多通常 20~40 个;vecadd 这类简单核函数 3~5 个足够。

6.5 高级特性:剪枝大调优空间

当 config 数量超过 20 个时,可以用 prune_configs_by 在 benchmark 之前先过滤掉一批:

python
def prune_invalid_configs(configs, named_args, **kwargs):
    """根据运行时参数过滤明显不可行的 config。"""
    N_CTX = kwargs["N_CTX"]
    STAGE = kwargs["STAGE"]
    return [
        conf for conf in configs
        if conf.kwargs.get("BLOCK_M", 0) <= N_CTX
        and (conf.kwargs.get("BLOCK_M", 0) >= conf.kwargs.get("BLOCK_N", 0)
             or STAGE == 1)
    ]


@triton.autotune(
    configs=large_config_list,
    key=["N_CTX", "HEAD_DIM"],
    prune_configs_by={'early_config_prune': prune_invalid_configs},
)
@triton.jit
def attn_fwd_kernel(...):
    ...

prune_configs_by 还支持基于性能模型的剪枝:

python
def perf_model(N_CTX, HEAD_DIM, BLOCK_M, BLOCK_N, num_warps, **kwargs):
    """返回 config 的预测运行时间。"""
    return N_CTX * N_CTX * HEAD_DIM / (BLOCK_M * BLOCK_N * num_warps)


@triton.autotune(
    configs=large_config_list,
    key=["N_CTX", "HEAD_DIM"],
    prune_configs_by={
        'perf_model': perf_model,
        'top_k': 10,  # 只 benchmark 预测前 10 名
    },
)

这种方式适合调优空间庞大的复杂核函数(如 FlashAttention 的 40+ config)。

6.6 调优策略与最佳实践

6.6.1 写出"好调优"的核函数

  • 把所有可能影响性能的参数都设为 tl.constexpr:这样它们成为编译期常量,autotune 才能对每个组合编译独立版本。
  • 不要在核函数里写死 num_warpsnum_stages:让 autotune 自动处理。
  • 每个 config 的 kwargs 必须能让核函数跑通:别忘了添加 static_assert 检查不合法的组合。

6.6.2 部署前的预热

生产服务启动时,对所有典型 shape 都"预跑"一次,把 autotune 开销挪到部署阶段:

python
def warmup_all_shapes():
    """在服务启动时调用,把 autotune 移到部署阶段。"""
    typical_shapes = [
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]
    for M, N, K in typical_shapes:
        a = torch.randn(M, K, device='cuda', dtype=torch.float16)
        b = torch.randn(K, N, device='cuda', dtype=torch.float16)
        _ = matmul(a, b)  # 触发 autotune
    torch.cuda.synchronize()
    print("Warmup done.")

6.6.3 处理写入型核函数

如果核函数会写一个 tensor(典型如 in-place 更新、累加),autotune 多次执行会让结果累加 N 次,必须配 reset_to_zero

python
@triton.autotune(
    configs=[...],
    key=['N'],
    reset_to_zero=['output_ptr'],  # 每次执行前把 output 清零
)
@triton.jit
def my_kernel(input_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N
    x = tl.load(input_ptr + offs, mask=mask)
    # 原子累加,多次执行会重复加
    tl.atomic_add(output_ptr, tl.sum(x))

类似地,对于需要保留输入原值的核函数,用 restore_value 在每次执行后还原。

6.7 注意事项与坑

坑 1:调优时间失控

autotune 默认会 benchmark 所有 config,并且每个跑多次取中位。Config 数量 × 评估次数 = 总耗时,容易爆炸。社区有报告过 FlashAttention 单 shape 调优 82 分钟的案例(GitHub issue #9401)。

对策

  • 控制 config 数量在 10~20 个
  • prune_configs_by 过滤无效组合
  • 优先调最常用的几个 shape,其它 shape 用相近 shape 的最佳 config

坑 2:benchmark 测不准

罗切斯特大学论文《Characterizing Autotuning Costs in OpenAI's Triton》指出,Triton 的 autotuner 在估算"warmed-up 核函数运行时间"时用未预热样本估算,导致快速核函数的 benchmark 时间被严重低估——本来想测 100ms,实际可能只跑了不到 20ms。

对策:对性能敏感的算子,最终选定的 config 应该用 triton.testing.do_bench 在真实数据上独立 benchmark 复核一次,而不要完全信赖 autotune 的内部测量。

坑 3:写入型核函数结果错误

忘了配 reset_to_zero 时,autotune 多次累加导致输出值是真实值的 N 倍——但核函数本身没报错,只是数值错。

对策:所有写输出的核函数上 autotune 之前,先用单 config 验证正确性,再加 autotune;并且显式声明所有需要 reset 的指针。

坑 4:跨硬件配置混用

NVIDIA 和 AMD 的最优 num_stages 不一致,简单 reuse 一套 config 会让一种硬件性能差 30%+。

对策:用 is_cuda() / is_hip() 判断后返回不同的 config 列表。

6.8 编译开销量化分析

到此为止我们都把 autotune 当成"免费午餐"。它不是。每个 config 都要付出一笔可观的隐藏成本,搞清楚账本才能在生产环境理性使用。

6.8.1 一次 autotune 的开销分解

罗切斯特大学 2024 年的研究论文 Characterizing Autotuning Costs in OpenAI's Triton 把单次 autotune 的耗时拆成三部分:

阶段占比说明
JIT 编译(每 config 一次)占总耗时 ~80%走完 triton.jit → TTIR → TTGIR → LLVM IR → PTX → SASS 全流程
Kernel 执行 + benchmark占总耗时 ~10%多次 launch 取中位,含一次 warmup
L2 cache 清空 + 框架开销占总耗时 ~50%(非 kernel/编译部分的大头)论文识别出的最大"隐式"成本

注意:因为编译和 L2 清空可以部分并行,三者加起来超过 100%。论文的 parallel-compilation 改造把编译阶段并行化后,整体 autotune 时间最高可降 7.9×

社区报告的极端数据:

  • FlashAttention 单 shape × 40 configs:曾耗时 82 分钟 / ~4631 秒GitHub issue #9401)。
  • matmul 6 configs:A100 上典型 0.4~1.5 秒,前 5 个 shape 共耗 5~10 秒。
  • vLLM 在线服务首请求:未做预热时第一个 prompt 的 attention kernel 可触发数百毫秒到数秒的 autotune,p99 延迟尖刺明显。

6.8.2 估算公式

给定 N 个 config,单 config 编译时间 T_compile(典型 0.5~3 秒),单次 benchmark 时间 T_bench(典型 25~125 ms):

text
T_autotune ≈ N × (T_compile + T_bench)
           ≈ N × T_compile          (编译时间通常远大于 benchmark)

经验值:

Config 数简单 kernel(elementwise)matmulFlashAttention
6~1 s~3 s~10 s
16~3 s~10 s~40 s
40~8 s~25 s~120 s
100+~20 s~70 s可达 30~80 分钟

别把 autotune 留给首请求

在线推理服务最忌讳"用户的第一个请求触发 autotune"——p99 延迟会爆掉。Spheron 在 2026 部署指南里测得:未预热的 Docker 容器冷启动时,30~120 秒会被 Triton 编译吃掉

6.8.3 生产部署的 Ahead-of-Time (AOT) 策略

针对生产环境的三种主流方案:

方案 1:服务启动时主动 warmup(最简单)

python
def warmup_all_shapes():
    """在容器启动时调用,把 autotune 开销挪到部署阶段。"""
    typical_shapes = [
        (1, 4096),       # batch=1 推理
        (8, 4096),       # batch=8 推理
        (32, 2048),      # batch=32 训练
        (1, 8192),       # 长 context
    ]
    for B, N in typical_shapes:
        x = torch.randn(B, N, device='cuda', dtype=torch.float16)
        _ = my_kernel(x)
    torch.cuda.synchronize()

# Dockerfile / Kubernetes 启动钩子里调用
if __name__ == '__main__':
    warmup_all_shapes()
    serve_inference()

方案 2:持久化 PTX/SASS 缓存(跨重启)

bash
# Dockerfile
ENV TRITON_CACHE_DIR=/triton_cache
# k8s deployment.yaml 把 /triton_cache 挂载为 PVC

下次启动后,编译阶段 命中 PTX 缓存(省 80% autotune 时间),但 autotune 的 best config 选择 仍是进程级缓存,依然要跑一遍 benchmark。

方案 3:IBM Dejavu / 离线 config 注入(最彻底)

IBM 在 Ray Summit 2024 开源的 triton-dejavu 把 autotune 选中的 config 序列化到 JSON,下次启动直接读取,完全跳过 benchmark

python
# 第一次跑:发现并保存
from triton_dejavu import autotune  # 替代 triton.autotune
@autotune(configs=[...], key=['M', 'N', 'K'],
          dejavu_dir='/configs/matmul.json')
@triton.jit
def matmul_kernel(...): ...

# 部署时:JSON 已存在,直接命中
# {"(4096, 4096, 4096)": "BLOCK_M:128, BLOCK_N:256, num_warps:8, num_stages:3"}

vLLM 团队报告:用 dejavu 把 首请求 autotune 开销从秒级降到零,同时保留 cross-GPU portability。

6.8.4 triton.runtime.autotuner 的缓存语义

理解缓存层级能帮你定位"为什么明明缓存了还在跑":

text
┌─────────────────────────────────────────────────┐
│ 1. Autotuner.cache (Python dict, 进程内)         │ ← key tuple → best Config
│    生命周期:Python 进程;重启即失效              │
├─────────────────────────────────────────────────┤
│ 2. JITFunction.cache (Python dict, 进程内)       │ ← (config, signature) → CompiledKernel
│    生命周期:Python 进程;编译产物的内存索引       │
├─────────────────────────────────────────────────┤
│ 3. TRITON_CACHE_DIR (磁盘)                      │ ← hash → .ttir/.ttgir/.ptx/.cubin
│    生命周期:手动清理之前都在;可跨进程/容器      │
└─────────────────────────────────────────────────┘
实测:缓存命中行为
python
import os, time, torch, triton
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'

# 第 1 次:完整 autotune(编译 + benchmark)
t0 = time.time(); _ = matmul(a, b); print(f'1st: {time.time()-t0:.2f}s')
# Triton autotuning ... finished after 3.21s

# 第 2 次:进程内缓存命中
t0 = time.time(); _ = matmul(a, b); print(f'2nd: {time.time()-t0:.4f}s')
# 2nd: 0.0001s

# 重启 Python 进程后(TRITON_CACHE_DIR 已设置)
# PTX 缓存命中 → 跳过编译;但 autotuner 缓存丢失 → 重新 benchmark
# Triton autotuning ... finished after 0.45s (省了 80% 编译时间,但仍要 benchmark)

6.9 搜索空间设计的数学建模

随手列 10 个 config 凭运气工作,列得不好性能直接打七折。系统化设计搜索空间需要从硬件 spec 反推合理的参数范围。

6.9.1 BLOCK_SIZE 与 occupancy 的关系

GPU occupancy(SM 上同时驻留的 warp 数 / 最大 warp 数)由三个 SM 资源同时决定,取最严的那个:

$$ \text{Occupancy} = \min!\left( \underbrace{\frac{\text{Regs}{\text{SM}}}{R \cdot 32 \cdot W}}{\text{寄存器限制}},; \underbrace{\frac{\text{SMEM}{\text{SM}}}{S}}{\text{共享内存限制}},; \underbrace{\text{MaxBlocks}{\text{SM}}}{\text{硬件常量}} \right) $$

其中 R = 每线程寄存器数,W = num_warps,S = 单 block 共享内存字节,Regs_SMSMEM_SM 来自硬件 spec(A100:65536 regs、164 KB SMEM;H100:65536 regs、228 KB SMEM)。

反向推导 BLOCK_SIZE 上限(以 fp16 matmul 为例):

text
单 block 共享内存 S = 2 × (BM × BK + BK × BN) × 2 bytes   # A、B 两个 tile,fp16
+ num_stages × 上述容量                                  # 多缓冲

要让 occupancy ≥ 4 blocks/SM(经验值),需要:
  S ≤ SMEM_SM / 4
A100 (164 KB / 4 = 41 KB) 下取 num_stages=3:
  3 × 4 × (BM·BK + BK·BN) ≤ 41 × 1024
  BM·BK + BK·BN ≤ 3500
当 BK=32, BM=BN=128 时:32·128 + 32·128 = 8192 → 超
当 BK=32, BM=BN=64  时:32·64  + 32·64  = 4096 → 超但接近
当 BK=32, BM=64, BN=128 时:32·64+32·128 = 6144 → 仍超

⇒ A100 上 num_stages=3 + 128×128×32 已经把 occupancy 压到 ≤ 2 blocks/SM。

这就是为什么官方 matmul tutorial 在 A100 上 BLOCK_M=BLOCK_N=128, BLOCK_K=32 不是越大越好——再大就 SMEM 溢出,编译器要么降级 stages 要么 spill 寄存器。

6.9.2 num_warps 与 tile 形状的约束

tl.dot 内部把 [BM, BN] 输出 tile 分给 num_warps 个 warp,每 warp 负责 [BM/warps_m, BN/warps_n] 的子块,由编译器自动选 warps_m × warps_n 的拆分。约束:

  1. BM × BN 必须被 num_warps × MMA_tile_size 整除。在 NVIDIA 上 Tensor Core 的最小 mma 形状是 16×16(fp16)或 16×8(fp8),所以: $$ BM, BN \in {16, 32, 64, 128, 256, \ldots} \quad\text{且}\quad \frac{BM \cdot BN}{256} \geq \text{num_warps} $$

  2. num_warps 上限:单 program 最多 32 warps(NVIDIA),即 1024 线程。

    • BM=BN=128tile=16384 → num_warps ≤ 64,但实际 ≤ 16
    • BM=BN=64tile=4096 → num_warps ≤ 16,常用 4~8
    • BM=BN=32tile=1024 → num_warps ≤ 4
  3. 每 warp 的工作量经验值:单 warp 处理 ≥ 32×32 元素时 Tensor Core 利用率最高。

6.9.3 GROUP_SIZE_M 与 L2 命中率模型

L2 swizzling 通过让 grid 上相邻的 program 复用 A、B 矩阵的 tile 来提升 L2 命中率。建模:

text
对 [M, N, K] 的 matmul,分成 num_m × num_n 个输出 block。
朴素行主序:算前 row 个 block 需要 load 的 tile 数:
  A: row 个    (每个 block 独立加载一行)
  B: row × num_n 个   (每个 block 加载一列)
  total = row × (1 + num_n)

Grouped (GROUP_SIZE_M = g):每组 g × num_n 个 block 内复用:
  A: g × num_n / g = num_n 个     (一组内 A 行复用 num_n 次)
  B: g 个                         (一组内 B 列复用 g 次)
  total = num_n + g
  per-block ratio = (num_n + g) / (g × num_n)

最优 g 满足 dL/dg = 0

$$ g^* = \sqrt{N / BN} $$

g 大约取 √(N/BN)

N(输出列数)BN推荐 GROUP_SIZE_M
10241283
40961286
81921288
163842568

官方推荐统一用 GROUP_SIZE_M=8 是因为它在 N=4K~16K 的常见 LLM shape 下都接近最优。

6.9.4 从 hardware spec 反推最优参数的方法论

完整流程:

text
1. 查硬件 spec
   ├─ SMEM/SM        (A100=164KB, H100=228KB, B200=256KB)
   ├─ Regs/SM        (统一 65536)
   ├─ Max threads/SM (A100/H100=2048)
   └─ HBM BW         (A100=1.55TB/s, H100=3TB/s, B200=8TB/s)

2. 给定目标 occupancy(典型 4 blocks/SM 或 50%)

3. 计算每 block 可用 SRAM 上限

4. 反推 BM × BN × BK × num_stages × sizeof(dtype) ≤ 上限

5. 在满足上限的组合里枚举 BM ∈ {64,128,256}, BN ∈ {64,128,256}, BK ∈ {32,64}

6. 对每个 (BM, BN),按 4.9.2 计算合法 num_warps 取值

7. 用 GROUP_SIZE_M = √(N/BN) 收敛唯一选择

照此方法论生成的 config 列表通常只有 6~10 个,且每个都贴近硬件最优区。

6.10 prune_configs_by 高级实战

当 config 数量上 30 之后,剪枝的收益远大于"多跑几个 config"。

6.10.1 基于 hardware 限制的剪枝

python
def prune_by_smem(configs, named_args, **kwargs):
    """剪掉超过 A100 SMEM (164 KB) 的 config。"""
    SMEM_LIMIT = 100 * 1024  # 留 60K 给 stack 和其他开销
    dtype_bytes = 2  # fp16
    pruned = []
    for c in configs:
        BM = c.kwargs['BLOCK_SIZE_M']
        BN = c.kwargs['BLOCK_SIZE_N']
        BK = c.kwargs['BLOCK_SIZE_K']
        smem = c.num_stages * 2 * (BM * BK + BK * BN) * dtype_bytes
        if smem <= SMEM_LIMIT:
            pruned.append(c)
    return pruned

6.10.2 基于问题规模的剪枝

python
def prune_by_shape(configs, named_args, **kwargs):
    """小矩阵不用大 BLOCK;瘦长矩阵优先瘦长 BLOCK。"""
    M, N, K = kwargs['M'], kwargs['N'], kwargs['K']
    pruned = []
    for c in configs:
        BM = c.kwargs['BLOCK_SIZE_M']
        BN = c.kwargs['BLOCK_SIZE_N']
        # 规则 1:BLOCK 不能比矩阵大 2 倍以上(浪费)
        if BM > 2 * M or BN > 2 * N:
            continue
        # 规则 2:M < 128 时只用 BM ≤ 64
        if M < 128 and BM > 64:
            continue
        # 规则 3:方阵优先方块 tile(BM ≈ BN)
        if abs(M - N) < 0.1 * max(M, N) and abs(BM - BN) > 64:
            continue
        pruned.append(c)
    return pruned

6.10.3 自定义 perf_model

性能模型不需要精确——只需排序大致正确即可:

python
def matmul_perf_model(M, N, K,
                      BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
                      num_warps, num_stages, **kwargs):
    """返回预测的运行时间(任意单位,越小越快)。"""
    # 1. 计算 wave 数:grid 大小 / SM 数
    num_blocks = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)
    NUM_SM = 108  # A100
    waves = num_blocks / NUM_SM

    # 2. 单 block 的 FLOPs(fp16 Tensor Core 折算)
    flops_per_block = 2 * BLOCK_SIZE_M * BLOCK_SIZE_N * K
    # H100 Tensor Core fp16: ~989 TFLOPS / SM 数
    tc_throughput = 989e12 / NUM_SM
    compute_time_per_block = flops_per_block / tc_throughput

    # 3. 单 block 的 HBM 流量
    bytes_per_block = 2 * (BLOCK_SIZE_M * K + K * BLOCK_SIZE_N + BLOCK_SIZE_M * BLOCK_SIZE_N)
    mem_time_per_block = bytes_per_block / (3e12 / NUM_SM)  # H100 HBM

    # 4. 取 max(compute-bound 或 memory-bound)
    block_time = max(compute_time_per_block, mem_time_per_block)

    # 5. 惩罚低 occupancy(粗略:num_warps 太多/太少都有问题)
    occupancy_penalty = 1.0
    if num_warps > 8: occupancy_penalty = 1.2
    if num_warps < 4: occupancy_penalty = 1.3

    return waves * block_time * occupancy_penalty


@triton.autotune(
    configs=large_config_list,  # 100+ configs
    key=['M', 'N', 'K'],
    prune_configs_by={
        'perf_model': matmul_perf_model,
        'top_k': 10,  # 只 benchmark 前 10 名
    },
)
@triton.jit
def matmul_kernel(...): ...

6.10.4 实战:100+ → 10 configs 的剪枝

一个 LayerNorm kernel 的真实案例:

python
# 起始搜索空间(直接枚举)
ALL_CONFIGS = []
for BM in [64, 128, 256, 512]:
    for BN in [64, 128, 256, 512, 1024]:
        for warps in [2, 4, 8, 16]:
            for stages in [1, 2, 3, 4]:
                ALL_CONFIGS.append(
                    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN},
                                  num_warps=warps, num_stages=stages))
# len = 4 × 5 × 4 × 4 = 320

def early_prune(configs, named_args, **kwargs):
    N = kwargs['N']  # hidden_dim
    pruned = []
    for c in configs:
        BM, BN = c.kwargs['BLOCK_M'], c.kwargs['BLOCK_N']
        # 1. BN 必须 >= N(一行装进 SRAM)
        if BN < N: continue
        # 2. SMEM 限制
        if BM * BN * 2 > 96 * 1024: continue
        # 3. tile 总 size ≥ 4 × num_warps × 32 (Tensor Core 利用率)
        if BM * BN < 128 * c.num_warps: continue
        # 4. AMD 平台 num_stages ≤ 1
        if not is_cuda() and c.num_stages > 1: continue
        # 5. 去重:BM × num_warps 相同的只留 num_stages 最大的
        pruned.append(c)
    return pruned

@triton.autotune(
    configs=ALL_CONFIGS,
    key=['M', 'N'],
    prune_configs_by={'early_config_prune': early_prune,
                      'perf_model': layernorm_perf_model,
                      'top_k': 8},
)
@triton.jit
def layernorm_kernel(...): ...

实战效果:

阶段Config 数A100 autotune 耗时
朴素枚举320~12 分钟
early_config_prune~25~50 秒
+ perf_model top_k8~15 秒

性能上:剪枝后选中的 best config 与朴素枚举选中的 吞吐差 < 2%——说明 perf_model 把高优先级 config 都保住了。

perf_model 不要追求精确

试图把 perf_model 写得很准(包含寄存器、warp 调度细节等)通常得不偿失——耗费一周建模换来的精度提升,远不如多列几个 config 让 autotune 实测。模型只需排序大致正确即可。

本章小结

  • autotune 是必备工具:手工调参在多 shape × 多 GPU 场景下完全不可行;让编译器替你跑搜索。
  • @triton.autotune(configs=[...], key=[...]) 是最基本的用法:configs 是候选列表,key 决定何时重新调优。
  • triton.Config 的关键字段BLOCK_SIZE_*num_warpsnum_stages,每个都直接影响 SRAM / 寄存器预算与并行度。
  • AMD vs NVIDIAnum_stages 语义不同,跨硬件必须分平台配置。
  • 大调优空间用 prune_configs_by:基于硬件限制 + 问题规模 + perf_model 三层剪枝可以把 320 个 config 剪到 8 个,autotune 时间从 12 分钟降到 15 秒。
  • 生产部署的 AOT 策略:warmup、TRITON_CACHE_DIR 持久化、IBM Dejavu 离线 config 三选一,避免首请求 autotune 抖动(实测可达 30~120 秒冷启动)。
  • 搜索空间用数学建模设计:从硬件 spec 反推 occupancy、SRAM、num_warps 约束,能让 6~10 个 config 覆盖最优区。
  • 常见坑:调优时间爆炸(最高 82 分钟)、benchmark 测不准(默认 warmup=25ms 偏低 30%)、忘了 reset_to_zero、跨硬件 config 混用。

掌握了"挑参数"之后,下一章我们要解决"省搬运"——算子融合把多个连续算子合并成一个核函数,让中间张量只在共享内存里流转,省掉 HBM 往返。

思考题

  1. 你为一个 M=K=4096, N 动态变化的 GEMM 写了 autotune,key=['M', 'N', 'K']。在生产环境观察到首请求都很慢。请分析原因,并给出至少两种缓解方案。

  2. 设计一个 prune_configs_by 函数,对 matmul 的 config 列表剪枝:要求过滤掉 BLOCK_SIZE_M × BLOCK_SIZE_N 总元素数 > 32768 的配置(防止 SRAM 溢出)。给出函数签名与实现。

  3. 你的核函数输出一个 [M, N] 矩阵,每次调用会原地累加(即 C += A @ B)。如果直接套 autotune,会出什么问题?如何修复?写出修改后的装饰器配置。

基于 MIT 协议发布