"""
示例 03：矩阵乘法 (Tiled Matmul with Autotune)
=============================================

这是 Triton 最经典的实战 kernel，对应官方教程的 03-matrix-multiplication。
也是 Triton "用 25 行 Python 写出媲美 cuBLAS 的 GEMM" 这一说法的来源。

算法骨架（C = A @ B，A: [M, K], B: [K, N], C: [M, N]）：

    for m in range(0, M, BLOCK_M):          # 并行（program 维度 1）
        for n in range(0, N, BLOCK_N):      # 并行（program 维度 2，flatten 进 1D grid）
            acc = zeros((BLOCK_M, BLOCK_N), fp32)
            for k in range(0, K, BLOCK_K):  # 顺序累加
                a = A[m:m+BM, k:k+BK]       # [BM, BK]
                b = B[k:k+BK, n:n+BN]       # [BK, BN]
                acc += dot(a, b)            # tl.dot 触发 Tensor Core
            C[m:m+BM, n:n+BN] = acc

本示例展示：
    1. 二维 tiling（每个 program 处理一个 BLOCK_M × BLOCK_N 输出块）
    2. tl.dot 调用 Tensor Core，accumulator 保持 fp32 防止精度损失
    3. @triton.autotune 自动在多个 (BLOCK_*, num_warps, num_stages) 组合中选最优
    4. L2 cache swizzling：grouped program ordering 提升 ~10% 性能
    5. 用 stride 参数支持任意 layout（行主序、列主序、转置）
    6. 与 torch.matmul (cuBLAS) 对比

运行：
    python 03_matmul.py

注意：autotune 在首次调用时会跑遍所有 config，可能耗时数秒到数十秒；
      之后命中缓存就是零开销。

参考：
    https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
"""

import torch
import triton
import triton.language as tl


DEVICE = triton.runtime.driver.active.get_active_torch_device()


# -----------------------------------------------------------------------------
# 1. autotune 配置候选
# -----------------------------------------------------------------------------
# 每个 Config 描述一组 (meta-parameters, num_warps, num_stages)
# autotune 在首次特定 (M, N, K) 组合调用时会跑遍所有 config，记下最快的一个
#
# 参数含义：
#   BLOCK_SIZE_M/N/K  : tile 尺寸，决定 SRAM 占用 + Tensor Core 调度效率
#   GROUP_SIZE_M      : L2 swizzling 分组大小（见 kernel 第 1 节）
#   num_warps         : 每个 program 用多少 warp（每 warp = 32 线程）
#   num_stages        : 软件流水线深度，越深越能隐藏 DRAM 延迟，但占用更多 SRAM
#                       NVIDIA 常用 2~5；AMD 推荐 0~1
def get_autotune_configs():
    return [
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64,
             'GROUP_SIZE_M': 8},
            num_stages=3, num_warps=8),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32,
             'GROUP_SIZE_M': 8},
            num_stages=4, num_warps=4),
    ]


# -----------------------------------------------------------------------------
# 2. Triton kernel
# -----------------------------------------------------------------------------
@triton.autotune(
    configs=get_autotune_configs(),
    key=['M', 'N', 'K'],  # 仅当 (M,N,K) 变化时才重新调优
)
@triton.jit
def matmul_kernel(
    # 数据指针
    a_ptr, b_ptr, c_ptr,
    # 形状
    M, N, K,
    # 各张量的 stride（用 stride 而不是写死 layout，可以支持转置等）
    stride_am, stride_ak,   # A: 每行多少元素 / 每列多少元素
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # tile 尺寸（编译期常量）
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    每个 program 计算 C 的一个 [BLOCK_SIZE_M, BLOCK_SIZE_N] 输出块。
    """
    # ---- 1) Grouped program ordering: 优化 L2 cache 复用 ----
    #
    # 朴素行主序：相邻 program 沿 N 方向扫，需要把整列 B 反复装入 L2
    #
    # Grouped ordering：把每 GROUP_SIZE_M 行块视为一个"组"，组内按列主序
    # 这样相邻 program 共享 A 的 tile，组之间共享 B 的 tile
    # 官方文档报告：A100 上 220 → 245 TFLOPS (+11%)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # 最后一个 group 可能不满 GROUP_SIZE_M 行
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ---- 2) 构造 A、B 的初始块指针 ----
    #
    # A tile 形状: [BLOCK_SIZE_M, BLOCK_SIZE_K]
    # B tile 形状: [BLOCK_SIZE_K, BLOCK_SIZE_N]
    #
    # 用 % M / % N 对越界行/列做"折回"——loaded data 不重要，反正写回时
    # 用 mask 截掉。这样可以让内层 K 循环不再需要 mask，少一组 if。
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # 用广播构造 2D 指针块
    # a_ptrs[i, j] = a_ptr + offs_am[i] * stride_am + offs_k[j] * stride_ak
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # ---- 3) K 维主循环：分块累加 ----
    #
    # 关键点：accumulator 用 fp32！即使输入是 fp16，累加器也保持 fp32 精度，
    # 这是所有 GPU GEMM 库（cuBLAS / CUTLASS / Triton）的标配做法。
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # K 维 tail mask：当 K 不是 BLOCK_SIZE_K 整数倍时，最后一块要截掉
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # tl.dot 触发 Tensor Core（输入 fp16/bf16 → acc fp32）
        # 第三个参数是初始累加值，原地累加避免寄存器搬运
        accumulator = tl.dot(a, b, accumulator)

        # 沿 K 方向滑动指针
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # ---- 4) 把 fp32 累加结果转回 fp16 ----
    # 注意：如果输入是 fp32，此处应该 .to(tl.float32)（其实就是 no-op）
    c = accumulator.to(tl.float16)

    # ---- 5) 带 mask 写回 C ----
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# -----------------------------------------------------------------------------
# 3. Python wrapper
# -----------------------------------------------------------------------------
def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Triton 版矩阵乘法，签名等价于 torch.matmul(a, b)。
    输入: a [M, K], b [K, N]
    输出: c [M, N] (fp16)
    """
    assert a.is_cuda and b.is_cuda
    assert a.shape[1] == b.shape[0], f"形状不匹配：{a.shape} @ {b.shape}"
    assert a.dtype == b.dtype, "两个输入 dtype 必须一致"

    M, K = a.shape
    _, N = b.shape

    # 预分配输出（用 fp16 节省带宽 / 显存；累加器仍是 fp32）
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    # 1D grid：把 (pid_m, pid_n) 平铺到一维
    # 实际的二维分块在 kernel 内通过 grouped ordering 还原
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c


# -----------------------------------------------------------------------------
# 4. 正确性验证
# -----------------------------------------------------------------------------
def test_correctness():
    torch.manual_seed(0)

    # 故意选择不是 BLOCK_SIZE 整数倍的形状，测试边界 mask
    M, N, K = 512, 384, 256
    a = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
    b = torch.randn(K, N, device=DEVICE, dtype=torch.float16)

    c_triton = matmul(a, b)
    c_torch = torch.matmul(a, b)

    # fp16 GEMM 误差较大，常用 rtol/atol 比较而不是绝对相等
    max_diff = torch.max(torch.abs(c_triton.float() - c_torch.float())).item()
    rel_diff = max_diff / torch.max(torch.abs(c_torch)).item()

    print(f"[正确性] shape={M}x{K} @ {K}x{N}")
    print(f"  最大绝对误差: {max_diff:.3e}")
    print(f"  最大相对误差: {rel_diff:.3e}")

    # fp16 累加 fp32 通常 1e-2 相对误差以内
    assert torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2), \
        "Triton matmul 与 torch.matmul 差异过大"
    print("  [PASS]\n")

    # 再测一个更大、更接近实际工作负载的 shape
    M, N, K = 1024, 1024, 1024
    a = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
    b = torch.randn(K, N, device=DEVICE, dtype=torch.float16)
    c_triton = matmul(a, b)
    c_torch = torch.matmul(a, b)
    print(f"[正确性] shape={M}x{K} @ {K}x{N}")
    print(f"  最大绝对误差: "
          f"{torch.max(torch.abs(c_triton.float() - c_torch.float())).item():.3e}")
    assert torch.allclose(c_triton, c_torch, atol=1e-2, rtol=1e-2)
    print("  [PASS]\n")


# -----------------------------------------------------------------------------
# 5. 性能基准测试
# -----------------------------------------------------------------------------
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        # 方阵尺寸：256, 512, 1024, ..., 8192
        x_vals=[128 * i for i in range(2, 33)],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch (cuBLAS)'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='TFLOPS',
        plot_name='matmul-performance-fp16',
        args={},
    )
)
def benchmark(M, N, K, provider):
    a = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
    b = torch.randn(K, N, device=DEVICE, dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]

    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.matmul(a, b), quantiles=quantiles
        )
    else:  # triton
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: matmul(a, b), quantiles=quantiles
        )

    # matmul 算力：2 * M * N * K FLOPs（每个输出位置 K 次乘加 = 2K FLOPs）
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


# -----------------------------------------------------------------------------
# 6. 主入口
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 60)
    print("Triton 示例 03：矩阵乘法（with autotune）")
    print("=" * 60)

    test_correctness()

    print("[Benchmark] 跑 TFLOPS 对比（方阵 256~4096）")
    print("预期：在 A100/H100 上，Triton 可达 cuBLAS 的 90~100%")
    print("注意：autotune 在每个新 (M,N,K) 上首次调用时会跑遍所有 config")
    print("      因此 benchmark 总时长会比较长")
    print("-" * 60)
    benchmark.run(print_data=True, show_plots=False)
