4. 基础算子:从向量加法开始
本章你将动手写出第一个真正能跑的 Triton 核函数。我们以向量加法为例,把第 3 章学到的概念全部落地:
@triton.jit、tl.program_id、tl.arange、tl.load/store、网格启动、与 PyTorch 对照验证、做性能 benchmark。
4.1 任务定义
最简单的逐元素运算:给定两个长度为 N 的向量 x、y,计算 output = x + y。
import torch
x = torch.rand(1_000_000, device='cuda')
y = torch.rand(1_000_000, device='cuda')
output = x + y # PyTorch 一行搞定PyTorch 已经够快了,那为什么我们要费劲用 Triton 重写一遍?
- 这是最经典的入门示例,对应 Triton 官方教程
01-vector-add,能完整覆盖 Triton 的核心 API - 它揭示了 Triton 写所有逐元素算子的通用骨架——后续的 LayerNorm、softmax、激活函数都是它的变体
期望管理
向量加法是完全 memory-bound 的算子(每个元素只做一次加法,带宽几乎决定一切)。Triton 写出来的版本不会比 PyTorch 更快,能持平就是成功。它的价值在于教学,不在于性能。
4.2 @triton.jit 装饰器
import triton
import triton.language as tl
@triton.jit
def my_kernel(x_ptr, ..., BLOCK_SIZE: tl.constexpr):
...@triton.jit 做了三件事:
- 标记这个 Python 函数将被编译为 GPU 核函数
- 解析函数体为 Triton 子集语法(不是普通 Python——你不能在里面
print()、不能import、不能用任意 Python 库) - JIT 编译:首次以特定
(dtype, constexpr 值, ...)组合调用时,把它编译成针对该组合特化的 PTX/CUBIN;结果缓存到~/.triton/cache/
不同的 constexpr 值 = 不同的编译产物
如果你用 BLOCK_SIZE=1024 和 BLOCK_SIZE=2048 各调一次,会触发两次JIT 编译,得到两个独立的二进制。这是 Triton 高性能的关键——每种配置都是单独特化的,没有运行时分支。
4.3 tl.load / tl.store:与 HBM 交换数据的唯一通道
# 加载:把 BLOCK_SIZE 个元素从 HBM 拉进 SRAM/寄存器
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
# 存储:把 value 写回 HBM
tl.store(out_ptr + offsets, value, mask=mask)参数:
| 参数 | 含义 |
|---|---|
| 第一个参数 | 指针向量(基址指针 + 偏移整数向量) |
mask | bool 向量,False 处不读/写 |
other | (仅 load)mask 为 False 处填的默认值,默认 0 |
性能要点:连续访问
如果 offsets 是 [0, 1, 2, 3, ...] 这样连续的整数序列,编译器会自动生成"向量化加载"指令(一条 PTX 指令搬运多个元素),充分利用 HBM 带宽——这就是"自动内存合并 (coalescing)"。
如果 offsets 是 [0, 100, 200, 300, ...] 这样跳跃的,每个元素要单独发请求,带宽利用率会暴跌。让 offsets 尽量连续是写高性能 Triton 算子的第一守则。
4.4 完整代码:向量加法
下面是一个端到端的、可以直接 python 运行的版本(完整代码见 examples/01_vector_add.py)。我们分步骤拆解。
4.4.1 核函数部分
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@triton.jit
def add_kernel(
x_ptr, # 输入张量 x 的指针(指向首元素)
y_ptr, # 输入张量 y 的指针
output_ptr, # 输出张量的指针
n_elements, # 向量总长度(运行期标量)
BLOCK_SIZE: tl.constexpr, # 每个程序实例处理多少元素;编译期常量,必须是 2 的幂
):
# ---- 步骤 1:拿到当前程序实例在 1D 网格中的 ID ----
pid = tl.program_id(axis=0)
# ---- 步骤 2:计算本程序实例负责的全局偏移 ----
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# ---- 步骤 3:构造越界 mask ----
mask = offsets < n_elements
# ---- 步骤 4:从 HBM 加载两个输入块 ----
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# ---- 步骤 5:在寄存器内完成逐元素加法 ----
output = x + y
# ---- 步骤 6:把结果写回 HBM ----
tl.store(output_ptr + offsets, output, mask=mask)4.4.2 逐步讲解
步骤 1:pid = tl.program_id(axis=0)
每个程序实例拿到自己唯一的 ID。我们用 1D 网格,所以只查 axis=0。
设 N=1000、BLOCK_SIZE=128,会启动 ceil(1000/128) = 8 个程序实例,pid 取值范围 [0, 1, 2, ..., 7]。
步骤 2:构造 offsets
block_start = pid * BLOCK_SIZE # 标量:本程序实例对应的起始下标
offsets = block_start + tl.arange(0, BLOCK_SIZE) # 向量:本程序实例要处理的所有下标举例 pid=2, BLOCK_SIZE=128:
block_start = 256
tl.arange(0, 128) = [0, 1, 2, ..., 127]
offsets = [256, 257, 258, ..., 383]步骤 3:mask
mask = offsets < n_elements最后一个程序实例(pid=7, offsets = [896, ..., 1023])中,下标 1000-1023 都越界。mask 在这些位置为 False,后续 load/store 会自动忽略它们。
步骤 4:load
x = tl.load(x_ptr + offsets, mask=mask)x_ptr是 PyTorch 自动转换给 Triton 的"指向 x 首元素的指针"x_ptr + offsets得到一个指针向量(128 个指针)tl.load一次性把 128 个 float 从 HBM 拉到 SRAM/寄存器
步骤 5:加法
output = x + y这是分块级运算——x 和 y 是形状 (128,) 的 Triton tensor,+ 是逐元素加。底层会被编译为多条 SIMD 加法指令。
步骤 6:store
tl.store(output_ptr + offsets, output, mask=mask)mask 在这里同样关键,避免最后一个程序实例把垃圾数据写到 output 之外的内存。
4.4.3 Python wrapper:让核函数像普通函数一样被调用
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
assert x.is_cuda and y.is_cuda, "输入张量必须在 GPU 上"
assert x.shape == y.shape, f"形状不匹配:{x.shape} vs {y.shape}"
output = torch.empty_like(x)
n_elements = output.numel()
# 网格用 lambda 延迟求值
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# 启动核函数
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output四个关键点:
- 预分配输出:Triton 核函数只会"原地写入",不会返回值。要先
torch.empty_like(x)拿到 output 张量。 - 网格用 lambda:好处是改 BLOCK_SIZE 时网格自动适配。
meta是包含所有 constexpr 参数的字典。 triton.cdiv:向上取整除法 =(a + b - 1) // b。- 异步启动:
add_kernel[grid](...)返回得很快,核函数还在 GPU 上跑。后续 PyTorch 用到 output 时(比如.item()、print())会自动同步。
4.4.4 正确性验证
torch.manual_seed(0)
size = 98_432 # 故意选不是 1024 整数倍的尺寸,测 mask
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
max_diff = torch.max(torch.abs(output_torch - output_triton)).item()
assert max_diff < 1e-6, f"差异过大:{max_diff}"
print(f"[PASS] 最大误差 = {max_diff:.3e}")永远用"非整除"尺寸测 mask
新手写 Triton 最常见的 bug 就是忘了 mask。但如果你只用 size = 1024 * 100 这种整除尺寸测试,有 bug 也测不出来——因为最后一个程序实例没有越界。故意选 98432、12345 这样的非整除尺寸,才是有效的回归测试。
4.4.5 性能基准:用 triton.testing.do_bench
Triton 自带一套 benchmark 工具,能直接画出性能对比图:
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'],
x_vals=[2 ** i for i in range(12, 28, 1)], # 4K ~ 128M 元素
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='GB/s',
plot_name='vector-add-performance',
args={},
)
)
def benchmark(size, provider):
x = torch.rand(size, device=DEVICE, dtype=torch.float32)
y = torch.rand(size, device=DEVICE, dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
else:
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
# 向量加法的有效带宽 = (读 2 个输入 + 写 1 个输出) × 元素字节数 / 时间
gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(print_data=True, show_plots=False)在 A100 / H100 这类卡上,你会看到 Triton 和 Torch 两条线几乎重合——这正是 memory-bound 算子的预期。
4.5 这个骨架适用于所有逐元素算子
把上面的核函数中 output = x + y 这一行换成别的,立刻得到一系列新算子:
output = x * y # 逐元素乘
output = x * x + y * y # 平方和
output = tl.sigmoid(x) # sigmoid
output = tl.maximum(x, 0.0) # ReLU
output = x * tl.sigmoid(x) # SiLU / Swish
output = tl.where(mask_cond, x, y) # 条件选择整个 load → compute → store 的骨架完全不变。这就是 Triton 的威力:一旦掌握骨架,所有逐元素算子都是十分钟内可以写完的小事。
4.6 完整可运行代码
完整代码(含 benchmark 与 ASCII 表格输出)见 examples/01_vector_add.py。运行方式:
cd examples
python 01_vector_add.py预期输出(A100 为例):
============================================================
Triton 示例 01:向量加法
============================================================
[正确性] size=98432
torch 前 5 个: [...]
triton 前 5 个: [...]
最大绝对误差: 0.000e+00
[PASS]
[Benchmark] 跑带宽对比(结果以 GB/s 为单位,越高越好)
vector-add-performance:
size Triton Torch
4096.0 89.7 88.3
...
67108864.0 1413.2 1410.14.7 性能分析方法论:把 GB/s 转成"是否打满"
前面跑出的 benchmark 输出"Triton 1413 GB/s、Torch 1410 GB/s"。一个值得训练的肌肉记忆是:看到数字立刻问"距离理论上限还有多远"。
4.7.1 算理论带宽上限
向量加法每个元素读 2 次、写 1 次,总 IO 量:
total_bytes = 3 × N × sizeof(element)测得 kernel 用时 t 秒后:
achieved_BW (GB/s) = total_bytes / t / 1e9
bandwidth_utilization = achieved_BW / peak_HBM_BW各代 NVIDIA 卡的 peak HBM 带宽(参考值):
| GPU | HBM 类型 | Peak BW |
|---|---|---|
| V100 32 GB | HBM2 | 900 GB/s |
| A100 40 GB | HBM2 | 1555 GB/s |
| A100 80 GB SXM | HBM2e | 2039 GB/s |
| H100 80 GB SXM | HBM3 | ~3350 GB/s |
| H200 141 GB | HBM3e | ~4800 GB/s |
写入 out_ptr 的 1× 不是免费的
完整账:x 读 1×、y 读 1×、out 写 1×,但 GPU 还会在写之前"读一下" out 的对应 cache line(write-allocate 协议),所以一些教科书会用 4×N 而不是 3×N。HBM 的 streaming store 通道大多数情况下能优化掉这次 RFO,所以业内 benchmark 普遍用 3×N,与官方 Triton tutorial 一致。
4.7.2 一个完整的算例
假设 A100 80GB SXM 上跑 N = 64 M (即 67108864) fp32,用 triton.testing.do_bench 测得 t = 0.142 ms:
total_bytes = 3 × 67_108_864 × 4 = 805_306_368 B ≈ 805 MB
achieved_BW = 805_306_368 / (0.142e-3) / 1e9 = 5670 GB/s ???等等——这个数字超过了 A100 的 2039 GB/s peak,说明数据在 L2 cache 里命中,并不是真正打 HBM。要测真实带宽,必须用大于 L2 容量(A100 = 40 MB)的输入:
N_safe = 256_000_000 (256M fp32 = 1 GB,远大于 L2)实测:A100 上向量加法稳定在 1900~1950 GB/s,bandwidth utilization ≈ 93~95%——这已经是 memory-bound 算子的天花板。
benchmark 易错点
- N 必须大于 L2:否则你只是在测 L2 带宽(A100 ~2.2 TB/s、H100 ~5.5 TB/s),与 HBM 带宽差好几倍。
- 要预热:第一次调用包含 JIT 编译时间,
do_bench默认会预热但自定义循环要手动for _ in range(10): kernel(...)。 - 同步:
torch.cuda.synchronize()包在测时器内外的位置错了,会把异步 launch 时间误算进 kernel 时间。
4.7.3 arithmetic intensity:为什么向量加法永远 memory-bound
定义:
AI (FLOP/byte) = 算术运算数 / 实际 IO 字节数向量加法:
- FLOP = 1 (一次加法) per element
- IO = 12 bytes per element (fp32: 3 × 4 B)
- AI = 1 / 12 ≈ 0.083 FLOP/byte(fp32)
- AI = 1 / 6 ≈ 0.17 FLOP/byte(fp16/bf16)
A100 fp16 tensor core 312 TFLOPS、HBM 2039 GB/s 的 ridge point = 312e12 / 2039e9 ≈ 153 FLOP/byte——向量加法的 AI 比这低三个数量级。
结论:在 roofline 上向量加法永远卡在
peak_BW × AI这条斜线上,即便算力翻 100 倍它也快不了。"不要妄想 Triton 写 elementwise 能超越 cuBLAS / PyTorch"——你只能争取打满带宽。
详细的 roofline 分析见第 5.x 节。
4.7.4 与 PyTorch 内置实现的微小差距来自哪里
实测中 Triton 和 torch.add 通常差 0~3%,差距来源大致这几条:
- Kernel launch overhead:每次 launch 大约 5~10 μs。小输入(如 N=4096,耗时仅 1 μs 级)launch 开销占比可达 80%,大输入趋近 0。
- JIT 编译开销:首次调用 Triton 要解析、降级、
ptxas、cache。约 100 ms ~ 几秒(取决于配置数)。do_bench 通常会预热掉。 - PyTorch 走 cuBLAS/CUDA core 高度优化的 elementwise kernel:底层是 NVIDIA 工程师手写多年调过的 PTX,常见会用
LDGSTS.E.128等硬件指令。Triton 编译产物可能差一两条指令。 - block size 选择:如果你 hardcode
BLOCK_SIZE=1024,对于极小输入是浪费、极大输入是不够;PyTorch heuristic 会按 size 切。autotune 后差距通常消失。
一个意外结论
Triton 在中等大小输入(1M ~ 16M 元素)上经常比 torch.add 快 1~5%。原因是 PyTorch 的 dispatch 路径要走多层 Python → C++ → CUDA dispatcher,单次 launch 开销实际比 Triton 高。这点在写 fused kernel 时尤其重要——融合 5 个 elementwise op 进一个 Triton kernel,省掉 4 次 launch overhead 就是稳定的 1.2~2× 加速。
4.8 从向量加法看 Triton 编译产物
Triton 让你"假装在写 Python",但理解它编出了什么是从中级走向高级的分水岭。本节展示如何 dump 出 PTX、读出关键指令。
4.8.1 拿到 PTX 的最短代码
import torch, triton
# ... add_kernel 定义如前 ...
x = torch.rand(1024, device='cuda')
y = torch.rand(1024, device='cuda')
out = torch.empty_like(x)
# 触发一次 JIT,拿到 compiled handle
compiled = add_kernel[(1,)](x, y, out, 1024, BLOCK_SIZE=1024)
print(compiled.asm['ptx']) # PTX 汇编
# print(compiled.asm['ttgir']) # TTGIR (带 layout 的块级 IR)
# print(compiled.asm['llir']) # LLVM IR把输出保存到文件后,可以用 cuobjdump --dump-sass 反汇编 cubin 看到真正的 SASS:
ls ~/.triton/cache/<hash>/*.cubin | head -1
cuobjdump --dump-sass <path>.cubin > add_kernel.sass4.8.2 PTX 关键指令解读
向量加法的 PTX 大致长这样(节选,已简化):
PTX 片段
//---- 算 pid 和 offsets ----
mov.u32 %r1, %ctaid.x; // pid
shl.b32 %r2, %r1, 10; // pid * 1024 (BLOCK_SIZE)
//---- 算 thread 在 block 内的位置 ----
mov.u32 %r3, %tid.x;
add.s32 %r4, %r2, %r3; // base offset for this thread
//---- 向量化加载 x ----
mul.wide.s32 %rd1, %r4, 4; // offset * sizeof(fp32)
add.s64 %rd2, %rd_x, %rd1; // x_ptr + offset
ld.global.v4.f32 {%f1,%f2,%f3,%f4}, [%rd2]; // 一次取 4 个 fp32 = 16 B
//---- 向量化加载 y ----
add.s64 %rd3, %rd_y, %rd1;
ld.global.v4.f32 {%f5,%f6,%f7,%f8}, [%rd3];
//---- 向量加法 ----
add.f32 %f9, %f1, %f5;
add.f32 %f10, %f2, %f6;
add.f32 %f11, %f3, %f7;
add.f32 %f12, %f4, %f8;
//---- 向量化写回 ----
add.s64 %rd4, %rd_out, %rd1;
st.global.v4.f32 [%rd4], {%f9,%f10,%f11,%f12};要点解读:
| 指令 | 含义 | 看到它意味着什么 |
|---|---|---|
ld.global.v4.f32 | 一条指令加载 4 个 fp32 = 16 B | 完美向量化,coalesce 拿满 |
ld.global.f32 | 一条指令只加载 1 个 fp32 = 4 B | 没向量化,IO 指令数翻 4 倍 |
ld.global.nc.v4.f32 | 走只读 (non-coherent) 路径 | 读-only 数据走 read-only cache,省 L1 槽位 |
add.f32 × 4 | 标量加法 4 条 | fp32 没有 SIMD add;fp16 才会出现 add.f16x2 |
st.global.v4.f32 | 16 B 向量化写回 | 写也 coalesce 了 |
看到 ld.global.f32(非 v4)就要警觉
- 可能 BLOCK_SIZE 太小、单 thread 只摊到 1 个元素 → 增大 BLOCK_SIZE
- 可能 dtype 是 fp64 + BLOCK 不够大 → 增大 BLOCK 至 ≥ 4 × num_warps × 32 = 512
- 可能 offsets 是 gather 模式(不连续)→ 重整数据布局
4.8.3 num_warps 对生成代码的影响
num_warps 控制单个 program 内部用多少个 warp 执行——直接决定每 thread 要扛多少元素。同一个 BLOCK_SIZE=1024 在不同 num_warps 下的对比:
| num_warps | threads/program | elements/thread (fp32) | 典型生成 |
|---|---|---|---|
| 1 | 32 | 32 | 8 条 ld.global.v4.f32(每 thread 8 次 16 B 加载) |
| 2 | 64 | 16 | 4 条 ld.global.v4.f32 |
| 4 (default) | 128 | 8 | 2 条 ld.global.v4.f32 |
| 8 | 256 | 4 | 1 条 ld.global.v4.f32 |
| 16 | 512 | 2 | 1 条 ld.global.v2.f32(只剩 8 B 向量) |
| 32 | 1024 | 1 | 1 条 ld.global.f32(完全标量化) |
发现什么了?num_warps 太大反而会丢失向量化机会。num_warps=32 时单 thread 只剩 1 个元素,根本无法向量化,IO 指令翻 4 倍。
经验法则:
- 1D elementwise:
num_warps = max(1, min(BLOCK_SIZE / 256, 8)),即让每 thread 至少分到 4 个元素。 - 2D/matmul:
num_warps = 4 或 8,由 autotune 决定。 - 永远把 num_warps 加进 autotune 搜索空间:它和 BLOCK_SIZE 是耦合的,单独调一个是徒劳。
Triton 编译器何时自动向量化
TTGIR 阶段有个 LoadVectorizer pass:它扫描 tt.load 的指针向量是否连续 (contiguous),连续就把 sizePerThread 提升到 4(即 .v4 指令)。这个 pass 默认就开,你不需要任何注解——但 sizePerThread 的上限是 16 B (fp32 v4 / fp16 v8),超过 BLOCK 大小时会回落到更小宽度。
本章小结
- 一个 Triton 核函数的通用骨架只有四步:算 offsets → 加 mask → load → compute → store。
@triton.jit把 Python 函数变成可被 JIT 编译的 GPU 核函数;不同的 constexpr 值触发不同的特化编译。tl.load/tl.store是与 HBM 唯一的交换通道;让offsets连续就能享受自动内存合并。- 永远加 mask,并且用非整除尺寸做测试,防止 mask bug 被"看似正确"的整除尺寸掩盖。
- 向量加法 Triton 与 PyTorch 性能持平是预期;Triton 的真正价值会在第 7 章(算子融合)和第 8 章(FlashAttention)显现。
至此第一部分"基础篇"结束。接下来我们进入"性能篇"——第 5 章先讲清楚内存优化的硬件根基,让你明白为什么"减少 HBM 往返"是 Triton 调优的第一性原理。
思考题
- 把
add_kernel改写成 SiLU 激活函数y = x * sigmoid(x)。你需要改动几行代码?核函数的输入参数有什么变化? - 假设
BLOCK_SIZE=1024、n_elements=2000,会启动多少个程序实例?最后一个程序实例的offsets是什么?mask 中有多少个 True、多少个 False? - 如果把 wrapper 里的
grid = lambda meta: (...)改成grid = (triton.cdiv(n_elements, 1024),)(直接传元组),代码还能跑吗?两种写法各有什么优缺点?什么场景下必须用 lambda?