"""
示例 01：向量加法 (Vector Addition)
==================================

这是 Triton 入门的最经典示例，对应官方教程的 01-vector-add。

本示例展示：
    1. 如何用 @triton.jit 把一个普通 Python 函数编译为 GPU kernel
    2. tl.program_id：在 grid 中定位当前 program
    3. tl.arange + 指针算术：构造块级偏移
    4. tl.load / tl.store：带 mask 的边界安全读写
    5. 用 lambda 动态计算 grid 尺寸
    6. 与 PyTorch 对比验证正确性
    7. 用 triton.testing 做带宽基准测试

运行：
    python 01_vector_add.py

参考：
    https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html
"""

import torch
import triton
import triton.language as tl


# 自动选择当前可用的 GPU 设备（CUDA / ROCm 通用写法）
DEVICE = triton.runtime.driver.active.get_active_torch_device()


# -----------------------------------------------------------------------------
# 1. Triton kernel 定义
# -----------------------------------------------------------------------------
@triton.jit
def add_kernel(
    x_ptr,                    # 输入张量 x 的指针（指向首元素）
    y_ptr,                    # 输入张量 y 的指针
    output_ptr,               # 输出张量的指针
    n_elements,               # 向量总长度（运行期标量）
    BLOCK_SIZE: tl.constexpr, # 每个 program 处理多少元素；编译期常量，必须是 2 的幂
):
    """
    向量加法的 Triton kernel。

    Triton 的 SPMD 模型：grid 中每个 program (类似 CUDA 的 block) 处理一个连续
    BLOCK_SIZE 长度的数据切片。program 内部由编译器自动把工作分发到 warp / 线程。
    开发者完全看不到 threadIdx，专注于"块"这一抽象。
    """
    # ---- 步骤 1：拿到当前 program 在 1D grid 中的 ID ----
    # 类比 CUDA 的 blockIdx.x，但 program 一次处理一整块 BLOCK_SIZE 元素
    pid = tl.program_id(axis=0)

    # ---- 步骤 2：计算本 program 负责的全局偏移 ----
    # block_start: 当前块在整个向量中的起始下标
    # offsets: 形如 [block_start, block_start+1, ..., block_start+BLOCK_SIZE-1]
    # 这是一个 Triton tensor，后面所有运算自动按向量执行
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # ---- 步骤 3：构造越界 mask ----
    # n_elements 不一定是 BLOCK_SIZE 的整数倍，最后一个 program 可能跨界
    # mask 为 False 的位置不会触发 DRAM 访问，从而避免段错误
    mask = offsets < n_elements

    # ---- 步骤 4：从 HBM 加载两个输入块 ----
    # 指针 + 整数向量 = 指针向量；load 一次性把整个 block 拉进寄存器 / SRAM
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # ---- 步骤 5：在寄存器内完成逐元素加法 ----
    output = x + y

    # ---- 步骤 6：把结果写回 HBM，mask 处不写入 ----
    tl.store(output_ptr + offsets, output, mask=mask)


# -----------------------------------------------------------------------------
# 2. Python 端的 wrapper 函数
# -----------------------------------------------------------------------------
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Triton 版本的向量加法，签名与 torch.add 等价。

    主要做四件事：
        a) 预分配输出张量
        b) 用 lambda 计算 grid（让 BLOCK_SIZE 变化时 grid 自动适配）
        c) 启动 kernel（异步）
        d) 返回 output；后续 PyTorch 使用它时会自动同步
    """
    assert x.is_cuda and y.is_cuda, "输入张量必须在 GPU 上"
    assert x.shape == y.shape, f"形状不匹配：{x.shape} vs {y.shape}"

    output = torch.empty_like(x)
    n_elements = output.numel()

    # grid 用 lambda 延迟求值，这样 BLOCK_SIZE 变了不需要改 grid 写法
    # meta 是一个 dict，包含所有 constexpr 参数
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    # 启动语法：kernel[grid](*args, **constexpr_kwargs)
    # torch.Tensor 会被自动解释为指向其 data_ptr() 的指针
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    # 注意：kernel launch 是异步的；这里直接返回，下次 .item() / 打印时会自动同步
    return output


# -----------------------------------------------------------------------------
# 3. 正确性验证
# -----------------------------------------------------------------------------
def test_correctness():
    """与 PyTorch 内置加法对比，验证两者数值一致。"""
    torch.manual_seed(0)

    # 故意选一个不是 BLOCK_SIZE (1024) 整数倍的尺寸，测试 mask 是否正确
    size = 98_432
    x = torch.rand(size, device=DEVICE)
    y = torch.rand(size, device=DEVICE)

    output_torch = x + y
    output_triton = add(x, y)

    max_diff = torch.max(torch.abs(output_torch - output_triton)).item()
    print(f"[正确性] size={size}")
    print(f"  torch  前 5 个: {output_torch[:5].tolist()}")
    print(f"  triton 前 5 个: {output_triton[:5].tolist()}")
    print(f"  最大绝对误差:   {max_diff:.3e}")
    assert max_diff < 1e-6, "Triton 与 torch 结果差异过大"
    print("  [PASS]\n")


# -----------------------------------------------------------------------------
# 4. 性能基准测试（带宽对比）
# -----------------------------------------------------------------------------
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],                          # 横轴变量
        x_vals=[2 ** i for i in range(12, 28, 1)], # 4K ~ 128M 个元素
        x_log=True,
        line_arg='provider',                       # 同一张图上对比谁
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='GB/s',
        plot_name='vector-add-performance',
        args={},
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device=DEVICE, dtype=torch.float32)
    y = torch.rand(size, device=DEVICE, dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]  # 中位数 + 20% / 80% 分位数误差棒

    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: x + y, quantiles=quantiles
        )
    else:  # triton
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: add(x, y), quantiles=quantiles
        )

    # 向量加法的有效带宽：读两个输入 + 写一个输出 = 3 倍元素字节数
    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


# -----------------------------------------------------------------------------
# 5. 主入口
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 60)
    print("Triton 示例 01：向量加法")
    print("=" * 60)

    test_correctness()

    print("[Benchmark] 跑带宽对比（结果以 GB/s 为单位，越高越好）")
    print("说明：向量加法是 memory-bound 算子，Triton 通常贴近 PyTorch")
    print("      (PyTorch 底层也是高度优化的 CUDA / cuBLAS)")
    print("-" * 60)
    benchmark.run(print_data=True, show_plots=False)
