"""
示例 02：融合 Softmax (Fused Softmax)
====================================

这是 Triton 算子融合 (kernel fusion) 的经典演示，对应官方教程的 02-fused-softmax。

朴素 PyTorch softmax 的内存开销分析（设输入 x 形状为 [M, N]）：

    x_max       = x.max(dim=1)        # 读 MN, 写 M
    z           = x - x_max[:, None]  # 读 MN+M, 写 MN
    numerator   = exp(z)              # 读 MN, 写 MN
    denominator = numerator.sum(1)    # 读 MN, 写 M
    ret         = numerator / denom   # 读 MN+M, 写 MN
    -------------------------------------------
    合计：读 5MN + 2M, 写 3MN + 2M

理论下限（融合后）：读 MN + 写 MN = 2MN。因此融合后理论可获得 ~4× 加速。

本示例展示：
    1. 在单个 Triton kernel 内完成 max -> sub -> exp -> sum -> div 全流程
    2. 用 BLOCK_SIZE = next_pow2(n_cols) 把整行装入 SRAM
    3. 用 -inf 作为 mask 填充值，保证 max/exp 不受影响
    4. 数值稳定的 softmax（减去 max 再 exp）
    5. tl.max / tl.sum 等 reduction 操作

限制：要求 n_cols 能装进单个 program 的 SRAM（典型 ≤ 32K）。

运行：
    python 02_fused_softmax.py

参考：
    https://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.html
"""

import torch
import triton
import triton.language as tl


DEVICE = triton.runtime.driver.active.get_active_torch_device()


# -----------------------------------------------------------------------------
# 1. 朴素 PyTorch 实现（作为正确性 / 性能的对照基线）
# -----------------------------------------------------------------------------
def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    """逐步分解的 softmax，便于看清内存读写次数。"""
    # x: [M, N]
    x_max = x.max(dim=1)[0]               # [M]      读 MN, 写 M
    z = x - x_max[:, None]                # [M, N]   读 MN+M, 写 MN
    numerator = torch.exp(z)              # [M, N]   读 MN, 写 MN
    denominator = numerator.sum(dim=1)    # [M]      读 MN, 写 M
    ret = numerator / denominator[:, None]  # [M, N] 读 MN+M, 写 MN
    return ret


# -----------------------------------------------------------------------------
# 2. Triton 融合 kernel
# -----------------------------------------------------------------------------
@triton.jit
def softmax_kernel(
    output_ptr,                # 输出张量首指针 [M, N]
    input_ptr,                 # 输入张量首指针 [M, N]
    input_row_stride,          # 输入张量第 0 维步长（一行多少元素）
    output_row_stride,         # 输出张量第 0 维步长
    n_rows,                    # M
    n_cols,                    # N
    BLOCK_SIZE: tl.constexpr,  # >= n_cols 的最小 2 的幂
):
    """
    一个 program 负责处理一行（或若干行）。

    设计要点：
        - 每行整体读入寄存器/SRAM，做 max / exp / sum / div 全融合
        - BLOCK_SIZE >= n_cols 保证一行能整块加载
        - 越界位置用 -inf 填充，保证 max 不被污染、exp(-inf)=0 不影响 sum
    """
    # ---- 用循环让一个 program 处理多行，减少 program 启动开销 ----
    # 这种写法被称为 "persistent kernel"：grid 大小固定（= SM 数），
    # 每个 program 在循环中消费多行。
    row_start = tl.program_id(axis=0)
    row_step = tl.num_programs(axis=0)

    for row_idx in tl.range(row_start, n_rows, row_step):
        # ---- 1) 定位本行起始指针 ----
        row_start_ptr = input_ptr + row_idx * input_row_stride

        # ---- 2) 列方向偏移 ----
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets

        # ---- 3) 加载整行，越界用 -inf 填充 ----
        # 为什么是 -inf？
        #   max(-inf, x) = x       → 不影响 max 计算
        #   exp(-inf - max) = 0    → 不贡献到 sum
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))

        # ---- 4) 数值稳定的 softmax，全部在寄存器内完成 ----
        # tl.max 沿 axis=0 对整个 1D 向量做 reduction
        row_minus_max = row - tl.max(row, axis=0)

        # tl.exp 类似 CUDA 的 __expf，快但有少量误差；对训练通常已足够
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator

        # ---- 5) 写回 ----
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)


# -----------------------------------------------------------------------------
# 3. Python 端 wrapper
# -----------------------------------------------------------------------------
def softmax(x: torch.Tensor) -> torch.Tensor:
    """
    沿最后一维 (dim=1) 做 softmax，等价于 torch.softmax(x, dim=1)。
    """
    assert x.is_cuda, "输入张量必须在 GPU 上"
    assert x.dim() == 2, "本示例只处理 2D 张量"

    n_rows, n_cols = x.shape

    # BLOCK_SIZE 必须是 2 的幂；这里取 >= n_cols 的最小 2 的幂
    # 当 n_cols 较大时，可能需要更多寄存器 / 更大共享内存
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # num_warps：单 program 用多少 warp，越大并行度越高但寄存器越紧张
    # 经验值：
    #   n_cols <= 2048 → 4 warps
    #   n_cols <= 4096 → 8 warps
    #   更大          → 16 warps
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16

    y = torch.empty_like(x)

    # Grid 大小直接取 n_rows，但每个 program 用循环消费多行更省启动开销
    # 简化版：grid = (n_rows,)，每个 program 只处理一行
    # 完整 persistent 版：grid = (NUM_SM,)，配合循环步长 num_programs
    grid = (n_rows,)

    softmax_kernel[grid](
        y, x,
        x.stride(0), y.stride(0),
        n_rows, n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return y


# -----------------------------------------------------------------------------
# 4. 正确性验证
# -----------------------------------------------------------------------------
def test_correctness():
    torch.manual_seed(0)

    # 故意选 n_cols=781（不是 2 的幂），测试 BLOCK_SIZE padding + mask
    x = torch.randn(1823, 781, device=DEVICE)

    y_triton = softmax(x)
    y_torch = torch.softmax(x, dim=1)
    y_naive = naive_softmax(x)

    max_diff_torch = torch.max(torch.abs(y_triton - y_torch)).item()
    max_diff_naive = torch.max(torch.abs(y_triton - y_naive)).item()

    print(f"[正确性] shape=(1823, 781)")
    print(f"  vs torch.softmax  最大绝对误差: {max_diff_torch:.3e}")
    print(f"  vs naive_softmax  最大绝对误差: {max_diff_naive:.3e}")

    # softmax 输出每行应和为 1
    row_sums = y_triton.sum(dim=1)
    print(f"  每行和的均值: {row_sums.mean().item():.6f} (理论 1.0)")
    print(f"  每行和的最大偏差: {torch.max(torch.abs(row_sums - 1.0)).item():.3e}")

    assert max_diff_torch < 1e-5, "Triton softmax 与 torch 差异过大"
    print("  [PASS]\n")


# -----------------------------------------------------------------------------
# 5. 性能基准测试
# -----------------------------------------------------------------------------
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[128 * i for i in range(2, 100)],  # 列数从 256 扫到 ~12800
        line_arg='provider',
        line_vals=['triton', 'torch', 'naive'],
        line_names=['Triton (fused)', 'Torch (cuDNN)', 'Naive (jit)'],
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],
        ylabel='GB/s',
        plot_name='softmax-performance',
        args={'M': 4096},  # 固定 M，扫 N
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)

    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))
    elif provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    else:  # naive
        ms = triton.testing.do_bench(lambda: naive_softmax(x))

    # softmax: 理论上读 1 次 + 写 1 次 = 2 * M * N * 4 bytes
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


# -----------------------------------------------------------------------------
# 6. 主入口
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 60)
    print("Triton 示例 02：融合 Softmax")
    print("=" * 60)

    test_correctness()

    print("[Benchmark] 跑带宽对比 (M=4096, N 从 256 扫到 ~12K)")
    print("预期：Triton fused 接近 HBM 峰值；naive 因未融合慢约 4×")
    print("-" * 60)
    benchmark.run(print_data=True, show_plots=False)
