Skip to content

4. 基础算子:从向量加法开始

本章你将动手写出第一个真正能跑的 Triton 核函数。我们以向量加法为例,把第 3 章学到的概念全部落地@triton.jittl.program_idtl.arangetl.load/store、网格启动、与 PyTorch 对照验证、做性能 benchmark。

4.1 任务定义

最简单的逐元素运算:给定两个长度为 N 的向量 xy,计算 output = x + y

python
import torch

x = torch.rand(1_000_000, device='cuda')
y = torch.rand(1_000_000, device='cuda')
output = x + y      # PyTorch 一行搞定

PyTorch 已经够快了,那为什么我们要费劲用 Triton 重写一遍

  • 这是最经典的入门示例,对应 Triton 官方教程 01-vector-add,能完整覆盖 Triton 的核心 API
  • 它揭示了 Triton 写所有逐元素算子的通用骨架——后续的 LayerNorm、softmax、激活函数都是它的变体

期望管理

向量加法是完全 memory-bound 的算子(每个元素只做一次加法,带宽几乎决定一切)。Triton 写出来的版本不会比 PyTorch 更快,能持平就是成功。它的价值在于教学,不在于性能。

4.2 @triton.jit 装饰器

python
import triton
import triton.language as tl

@triton.jit
def my_kernel(x_ptr, ..., BLOCK_SIZE: tl.constexpr):
    ...

@triton.jit 做了三件事:

  1. 标记这个 Python 函数将被编译为 GPU 核函数
  2. 解析函数体为 Triton 子集语法(不是普通 Python——你不能在里面 print()、不能 import、不能用任意 Python 库)
  3. JIT 编译:首次以特定 (dtype, constexpr 值, ...) 组合调用时,把它编译成针对该组合特化的 PTX/CUBIN;结果缓存到 ~/.triton/cache/

不同的 constexpr 值 = 不同的编译产物

如果你用 BLOCK_SIZE=1024BLOCK_SIZE=2048 各调一次,会触发两次JIT 编译,得到两个独立的二进制。这是 Triton 高性能的关键——每种配置都是单独特化的,没有运行时分支。

4.3 tl.load / tl.store:与 HBM 交换数据的唯一通道

python
# 加载:把 BLOCK_SIZE 个元素从 HBM 拉进 SRAM/寄存器
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

# 存储:把 value 写回 HBM
tl.store(out_ptr + offsets, value, mask=mask)

参数:

参数含义
第一个参数指针向量(基址指针 + 偏移整数向量)
maskbool 向量,False 处不读/写
other(仅 load)mask 为 False 处填的默认值,默认 0

性能要点:连续访问

如果 offsets[0, 1, 2, 3, ...] 这样连续的整数序列,编译器会自动生成"向量化加载"指令(一条 PTX 指令搬运多个元素),充分利用 HBM 带宽——这就是"自动内存合并 (coalescing)"。

如果 offsets 是 [0, 100, 200, 300, ...] 这样跳跃的,每个元素要单独发请求,带宽利用率会暴跌。让 offsets 尽量连续是写高性能 Triton 算子的第一守则。

4.4 完整代码:向量加法

下面是一个端到端的、可以直接 python 运行的版本(完整代码见 examples/01_vector_add.py)。我们分步骤拆解。

4.4.1 核函数部分

python
import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def add_kernel(
    x_ptr,                    # 输入张量 x 的指针(指向首元素)
    y_ptr,                    # 输入张量 y 的指针
    output_ptr,               # 输出张量的指针
    n_elements,               # 向量总长度(运行期标量)
    BLOCK_SIZE: tl.constexpr, # 每个程序实例处理多少元素;编译期常量,必须是 2 的幂
):
    # ---- 步骤 1:拿到当前程序实例在 1D 网格中的 ID ----
    pid = tl.program_id(axis=0)

    # ---- 步骤 2:计算本程序实例负责的全局偏移 ----
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # ---- 步骤 3:构造越界 mask ----
    mask = offsets < n_elements

    # ---- 步骤 4:从 HBM 加载两个输入块 ----
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # ---- 步骤 5:在寄存器内完成逐元素加法 ----
    output = x + y

    # ---- 步骤 6:把结果写回 HBM ----
    tl.store(output_ptr + offsets, output, mask=mask)

4.4.2 逐步讲解

步骤 1:pid = tl.program_id(axis=0)

每个程序实例拿到自己唯一的 ID。我们用 1D 网格,所以只查 axis=0

N=1000BLOCK_SIZE=128,会启动 ceil(1000/128) = 8 个程序实例,pid 取值范围 [0, 1, 2, ..., 7]

步骤 2:构造 offsets

python
block_start = pid * BLOCK_SIZE       # 标量:本程序实例对应的起始下标
offsets = block_start + tl.arange(0, BLOCK_SIZE)  # 向量:本程序实例要处理的所有下标

举例 pid=2, BLOCK_SIZE=128

block_start = 256
tl.arange(0, 128) = [0, 1, 2, ..., 127]
offsets = [256, 257, 258, ..., 383]

步骤 3:mask

python
mask = offsets < n_elements

最后一个程序实例(pid=7, offsets = [896, ..., 1023])中,下标 1000-1023 都越界。mask 在这些位置为 False,后续 load/store 会自动忽略它们。

步骤 4:load

python
x = tl.load(x_ptr + offsets, mask=mask)
  • x_ptr 是 PyTorch 自动转换给 Triton 的"指向 x 首元素的指针"
  • x_ptr + offsets 得到一个指针向量(128 个指针)
  • tl.load 一次性把 128 个 float 从 HBM 拉到 SRAM/寄存器

步骤 5:加法

python
output = x + y

这是分块级运算——xy 是形状 (128,) 的 Triton tensor,+ 是逐元素加。底层会被编译为多条 SIMD 加法指令。

步骤 6:store

python
tl.store(output_ptr + offsets, output, mask=mask)

mask 在这里同样关键,避免最后一个程序实例把垃圾数据写到 output 之外的内存。

4.4.3 Python wrapper:让核函数像普通函数一样被调用

python
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and y.is_cuda, "输入张量必须在 GPU 上"
    assert x.shape == y.shape, f"形状不匹配:{x.shape} vs {y.shape}"

    output = torch.empty_like(x)
    n_elements = output.numel()

    # 网格用 lambda 延迟求值
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # 启动核函数
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output

四个关键点:

  1. 预分配输出:Triton 核函数只会"原地写入",不会返回值。要先 torch.empty_like(x) 拿到 output 张量。
  2. 网格用 lambda:好处是改 BLOCK_SIZE 时网格自动适配。meta 是包含所有 constexpr 参数的字典。
  3. triton.cdiv:向上取整除法 = (a + b - 1) // b
  4. 异步启动add_kernel[grid](...) 返回得很快,核函数还在 GPU 上跑。后续 PyTorch 用到 output 时(比如 .item()print())会自动同步。

4.4.4 正确性验证

python
torch.manual_seed(0)
size = 98_432              # 故意选不是 1024 整数倍的尺寸,测 mask
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)

output_torch = x + y
output_triton = add(x, y)

max_diff = torch.max(torch.abs(output_torch - output_triton)).item()
assert max_diff < 1e-6, f"差异过大:{max_diff}"
print(f"[PASS] 最大误差 = {max_diff:.3e}")

永远用"非整除"尺寸测 mask

新手写 Triton 最常见的 bug 就是忘了 mask。但如果你只用 size = 1024 * 100 这种整除尺寸测试,有 bug 也测不出来——因为最后一个程序实例没有越界。故意选 98432、12345 这样的非整除尺寸,才是有效的回归测试。

4.4.5 性能基准:用 triton.testing.do_bench

Triton 自带一套 benchmark 工具,能直接画出性能对比图:

python
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[2 ** i for i in range(12, 28, 1)],   # 4K ~ 128M 元素
        x_log=True,
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='GB/s',
        plot_name='vector-add-performance',
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device=DEVICE, dtype=torch.float32)
    y = torch.rand(size, device=DEVICE, dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]

    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    else:
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)

    # 向量加法的有效带宽 = (读 2 个输入 + 写 1 个输出) × 元素字节数 / 时间
    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True, show_plots=False)

在 A100 / H100 这类卡上,你会看到 Triton 和 Torch 两条线几乎重合——这正是 memory-bound 算子的预期。

4.5 这个骨架适用于所有逐元素算子

把上面的核函数中 output = x + y 这一行换成别的,立刻得到一系列新算子:

python
output = x * y                      # 逐元素乘
output = x * x + y * y              # 平方和
output = tl.sigmoid(x)              # sigmoid
output = tl.maximum(x, 0.0)         # ReLU
output = x * tl.sigmoid(x)          # SiLU / Swish
output = tl.where(mask_cond, x, y)  # 条件选择

整个 load → compute → store 的骨架完全不变。这就是 Triton 的威力:一旦掌握骨架,所有逐元素算子都是十分钟内可以写完的小事。

4.6 完整可运行代码

完整代码(含 benchmark 与 ASCII 表格输出)见 examples/01_vector_add.py。运行方式:

bash
cd examples
python 01_vector_add.py

预期输出(A100 为例):

============================================================
Triton 示例 01:向量加法
============================================================
[正确性] size=98432
  torch  前 5 个: [...]
  triton 前 5 个: [...]
  最大绝对误差:   0.000e+00
  [PASS]

[Benchmark] 跑带宽对比(结果以 GB/s 为单位,越高越好)
vector-add-performance:
       size       Triton        Torch
    4096.0      89.7          88.3
    ...
  67108864.0   1413.2        1410.1

4.7 性能分析方法论:把 GB/s 转成"是否打满"

前面跑出的 benchmark 输出"Triton 1413 GB/s、Torch 1410 GB/s"。一个值得训练的肌肉记忆是:看到数字立刻问"距离理论上限还有多远"

4.7.1 算理论带宽上限

向量加法每个元素读 2 次、写 1 次,总 IO 量:

total_bytes = 3 × N × sizeof(element)

测得 kernel 用时 t 秒后:

achieved_BW (GB/s) = total_bytes / t / 1e9
bandwidth_utilization = achieved_BW / peak_HBM_BW

各代 NVIDIA 卡的 peak HBM 带宽(参考值):

GPUHBM 类型Peak BW
V100 32 GBHBM2900 GB/s
A100 40 GBHBM21555 GB/s
A100 80 GB SXMHBM2e2039 GB/s
H100 80 GB SXMHBM3~3350 GB/s
H200 141 GBHBM3e~4800 GB/s

写入 out_ptr 的 1× 不是免费的

完整账:x 读 1×、y 读 1×、out 写 1×,但 GPU 还会在写之前"读一下" out 的对应 cache line(write-allocate 协议),所以一些教科书会用 4×N 而不是 3×N。HBM 的 streaming store 通道大多数情况下能优化掉这次 RFO,所以业内 benchmark 普遍用 3×N,与官方 Triton tutorial 一致。

4.7.2 一个完整的算例

假设 A100 80GB SXM 上跑 N = 64 M (即 67108864) fp32,用 triton.testing.do_bench 测得 t = 0.142 ms

total_bytes = 3 × 67_108_864 × 4 = 805_306_368 B ≈ 805 MB
achieved_BW = 805_306_368 / (0.142e-3) / 1e9 = 5670 GB/s ???

等等——这个数字超过了 A100 的 2039 GB/s peak,说明数据在 L2 cache 里命中,并不是真正打 HBM。要测真实带宽,必须用大于 L2 容量(A100 = 40 MB)的输入:

N_safe = 256_000_000 (256M fp32 = 1 GB,远大于 L2)

实测:A100 上向量加法稳定在 1900~1950 GB/s,bandwidth utilization ≈ 93~95%——这已经是 memory-bound 算子的天花板。

benchmark 易错点

  1. N 必须大于 L2:否则你只是在测 L2 带宽(A100 ~2.2 TB/s、H100 ~5.5 TB/s),与 HBM 带宽差好几倍。
  2. 要预热:第一次调用包含 JIT 编译时间,do_bench 默认会预热但自定义循环要手动 for _ in range(10): kernel(...)
  3. 同步torch.cuda.synchronize() 包在测时器内外的位置错了,会把异步 launch 时间误算进 kernel 时间。

4.7.3 arithmetic intensity:为什么向量加法永远 memory-bound

定义:

AI (FLOP/byte) = 算术运算数 / 实际 IO 字节数

向量加法:

  • FLOP = 1 (一次加法) per element
  • IO = 12 bytes per element (fp32: 3 × 4 B)
  • AI = 1 / 12 ≈ 0.083 FLOP/byte(fp32)
  • AI = 1 / 6 ≈ 0.17 FLOP/byte(fp16/bf16)

A100 fp16 tensor core 312 TFLOPS、HBM 2039 GB/s 的 ridge point = 312e12 / 2039e9 ≈ 153 FLOP/byte——向量加法的 AI 比这低三个数量级。

结论:在 roofline 上向量加法永远卡在 peak_BW × AI 这条斜线上,即便算力翻 100 倍它也快不了。"不要妄想 Triton 写 elementwise 能超越 cuBLAS / PyTorch"——你只能争取打满带宽。

详细的 roofline 分析见第 5.x 节。

4.7.4 与 PyTorch 内置实现的微小差距来自哪里

实测中 Triton 和 torch.add 通常差 0~3%,差距来源大致这几条:

  1. Kernel launch overhead:每次 launch 大约 5~10 μs。小输入(如 N=4096,耗时仅 1 μs 级)launch 开销占比可达 80%,大输入趋近 0。
  2. JIT 编译开销:首次调用 Triton 要解析、降级、ptxas、cache。约 100 ms ~ 几秒(取决于配置数)。do_bench 通常会预热掉。
  3. PyTorch 走 cuBLAS/CUDA core 高度优化的 elementwise kernel:底层是 NVIDIA 工程师手写多年调过的 PTX,常见会用 LDGSTS.E.128 等硬件指令。Triton 编译产物可能差一两条指令。
  4. block size 选择:如果你 hardcode BLOCK_SIZE=1024,对于极小输入是浪费、极大输入是不够;PyTorch heuristic 会按 size 切。autotune 后差距通常消失。

一个意外结论

Triton 在中等大小输入(1M ~ 16M 元素)上经常比 torch.add 快 1~5%。原因是 PyTorch 的 dispatch 路径要走多层 Python → C++ → CUDA dispatcher,单次 launch 开销实际比 Triton 高。这点在写 fused kernel 时尤其重要——融合 5 个 elementwise op 进一个 Triton kernel,省掉 4 次 launch overhead 就是稳定的 1.2~2× 加速。

4.8 从向量加法看 Triton 编译产物

Triton 让你"假装在写 Python",但理解它编出了什么是从中级走向高级的分水岭。本节展示如何 dump 出 PTX、读出关键指令。

4.8.1 拿到 PTX 的最短代码

python
import torch, triton
# ... add_kernel 定义如前 ...

x = torch.rand(1024, device='cuda')
y = torch.rand(1024, device='cuda')
out = torch.empty_like(x)

# 触发一次 JIT,拿到 compiled handle
compiled = add_kernel[(1,)](x, y, out, 1024, BLOCK_SIZE=1024)

print(compiled.asm['ptx'])     # PTX 汇编
# print(compiled.asm['ttgir']) # TTGIR (带 layout 的块级 IR)
# print(compiled.asm['llir'])  # LLVM IR

把输出保存到文件后,可以用 cuobjdump --dump-sass 反汇编 cubin 看到真正的 SASS:

bash
ls ~/.triton/cache/<hash>/*.cubin | head -1
cuobjdump --dump-sass <path>.cubin > add_kernel.sass

4.8.2 PTX 关键指令解读

向量加法的 PTX 大致长这样(节选,已简化):

PTX 片段
ptx
//---- 算 pid 和 offsets ----
mov.u32      %r1, %ctaid.x;           // pid
shl.b32      %r2, %r1, 10;            // pid * 1024  (BLOCK_SIZE)
//---- 算 thread 在 block 内的位置 ----
mov.u32      %r3, %tid.x;
add.s32      %r4, %r2, %r3;           // base offset for this thread

//---- 向量化加载 x ----
mul.wide.s32  %rd1, %r4, 4;           // offset * sizeof(fp32)
add.s64       %rd2, %rd_x, %rd1;      // x_ptr + offset
ld.global.v4.f32 {%f1,%f2,%f3,%f4}, [%rd2];   // 一次取 4 个 fp32 = 16 B

//---- 向量化加载 y ----
add.s64       %rd3, %rd_y, %rd1;
ld.global.v4.f32 {%f5,%f6,%f7,%f8}, [%rd3];

//---- 向量加法 ----
add.f32       %f9,  %f1, %f5;
add.f32       %f10, %f2, %f6;
add.f32       %f11, %f3, %f7;
add.f32       %f12, %f4, %f8;

//---- 向量化写回 ----
add.s64       %rd4, %rd_out, %rd1;
st.global.v4.f32 [%rd4], {%f9,%f10,%f11,%f12};

要点解读:

指令含义看到它意味着什么
ld.global.v4.f32一条指令加载 4 个 fp32 = 16 B完美向量化,coalesce 拿满
ld.global.f32一条指令只加载 1 个 fp32 = 4 B没向量化,IO 指令数翻 4 倍
ld.global.nc.v4.f32走只读 (non-coherent) 路径读-only 数据走 read-only cache,省 L1 槽位
add.f32 × 4标量加法 4 条fp32 没有 SIMD add;fp16 才会出现 add.f16x2
st.global.v4.f3216 B 向量化写回写也 coalesce 了

看到 ld.global.f32(非 v4)就要警觉

  • 可能 BLOCK_SIZE 太小、单 thread 只摊到 1 个元素 → 增大 BLOCK_SIZE
  • 可能 dtype 是 fp64 + BLOCK 不够大 → 增大 BLOCK 至 ≥ 4 × num_warps × 32 = 512
  • 可能 offsets 是 gather 模式(不连续)→ 重整数据布局

4.8.3 num_warps 对生成代码的影响

num_warps 控制单个 program 内部用多少个 warp 执行——直接决定每 thread 要扛多少元素。同一个 BLOCK_SIZE=1024 在不同 num_warps 下的对比:

num_warpsthreads/programelements/thread (fp32)典型生成
132328 条 ld.global.v4.f32(每 thread 8 次 16 B 加载)
264164 条 ld.global.v4.f32
4 (default)12882 条 ld.global.v4.f32
825641 条 ld.global.v4.f32
1651221 条 ld.global.v2.f32(只剩 8 B 向量)
32102411 条 ld.global.f32完全标量化

发现什么了?num_warps 太大反而会丢失向量化机会num_warps=32 时单 thread 只剩 1 个元素,根本无法向量化,IO 指令翻 4 倍。

经验法则:

  • 1D elementwise:num_warps = max(1, min(BLOCK_SIZE / 256, 8)),即让每 thread 至少分到 4 个元素。
  • 2D/matmul:num_warps = 4 或 8,由 autotune 决定。
  • 永远把 num_warps 加进 autotune 搜索空间:它和 BLOCK_SIZE 是耦合的,单独调一个是徒劳。

Triton 编译器何时自动向量化

TTGIR 阶段有个 LoadVectorizer pass:它扫描 tt.load 的指针向量是否连续 (contiguous),连续就把 sizePerThread 提升到 4(即 .v4 指令)。这个 pass 默认就开,你不需要任何注解——但 sizePerThread 的上限是 16 B (fp32 v4 / fp16 v8),超过 BLOCK 大小时会回落到更小宽度。

本章小结

  • 一个 Triton 核函数的通用骨架只有四步:算 offsets → 加 mask → load → compute → store
  • @triton.jit 把 Python 函数变成可被 JIT 编译的 GPU 核函数;不同的 constexpr 值触发不同的特化编译。
  • tl.load / tl.store 是与 HBM 唯一的交换通道;让 offsets 连续就能享受自动内存合并。
  • 永远加 mask,并且用非整除尺寸做测试,防止 mask bug 被"看似正确"的整除尺寸掩盖。
  • 向量加法 Triton 与 PyTorch 性能持平是预期;Triton 的真正价值会在第 7 章(算子融合)和第 8 章(FlashAttention)显现。

至此第一部分"基础篇"结束。接下来我们进入"性能篇"——第 5 章先讲清楚内存优化的硬件根基,让你明白为什么"减少 HBM 往返"是 Triton 调优的第一性原理。

思考题

  1. add_kernel 改写成 SiLU 激活函数 y = x * sigmoid(x)。你需要改动几行代码?核函数的输入参数有什么变化?
  2. 假设 BLOCK_SIZE=1024n_elements=2000,会启动多少个程序实例?最后一个程序实例的 offsets 是什么?mask 中有多少个 True、多少个 False?
  3. 如果把 wrapper 里的 grid = lambda meta: (...) 改成 grid = (triton.cdiv(n_elements, 1024),)(直接传元组),代码还能跑吗?两种写法各有什么优缺点?什么场景下必须用 lambda?

基于 MIT 协议发布