"""
示例 04：FlashAttention 简化版
==============================

这是 Triton 高级编程的"集大成"示例，演示 FlashAttention v2 的核心思想。
对应官方教程 06-fused-attention 的简化版（去掉 causal / warp_specialize / TMA
等高级特性，专注于讲清算法本身）。

背景
----
标准 attention：

    O = softmax(QK^T / sqrt(d)) @ V

其中 Q, K, V 形状均为 [N, d]（N=seq_len, d=head_dim）。

朴素实现的问题：
    1. 必须 materialize 完整的 attention matrix S = QK^T，形状 [N, N]
    2. 当 N=4096 时单 head 就要 32 MB (fp16)，远超 SRAM 容量
    3. 整个 N×N 矩阵在 HBM 上来回搬运多次（max / sub / exp / sum / matmul）
    4. 显存随 N² 增长，长序列直接 OOM

FlashAttention 三大技巧
-----------------------
1. **Tiling**：把 Q、K、V 按 block 切分，永远不 materialize 完整的 S 矩阵
2. **Online softmax**：在 K/V 块循环中用 running max m_i 和 running sum l_i
   渐进式更新 softmax，新块到来时修正旧的累加结果：

       m_new = max(m_old, max(qk_block))
       alpha = exp(m_old - m_new)              # 旧累加器的修正系数
       l_new = alpha * l_old + sum(exp(qk_block - m_new))
       acc_new = alpha * acc_old + exp(qk_block - m_new) @ V_block

3. **Recomputation**（仅 backward）：bwd 不存中间矩阵，重算 P=exp(QK^T-m)/l

收益：
    - HBM 访问从 O(N²) 降到 O(N)
    - 显存从 O(N²) 降到 O(N)
    - 4096 token 单 head：朴素 32 MB → FlashAttention 72 KiB

本示例只实现 forward，且不含 causal mask，便于聚焦核心思想。

运行：
    python 04_flash_attention.py

参考：
    https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
    https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
    https://arxiv.org/abs/2205.14135 (FlashAttention v1)
    https://tridao.me/publications/flash2/flash2.pdf (FlashAttention v2)
"""

import math
import torch
import triton
import triton.language as tl


DEVICE = triton.runtime.driver.active.get_active_torch_device()


# -----------------------------------------------------------------------------
# 1. 朴素 attention（用于正确性对照 + 衡量加速比）
# -----------------------------------------------------------------------------
def naive_attention(q, k, v, sm_scale):
    """
    标准 attention，O(N²) 显存。

    输入形状均为 [BATCH, HEADS, N_CTX, HEAD_DIM]
    """
    # QK^T: [BATCH, HEADS, N_CTX, N_CTX]
    s = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    p = torch.softmax(s, dim=-1)
    o = torch.matmul(p, v)
    return o


# -----------------------------------------------------------------------------
# 2. Triton FlashAttention forward kernel
# -----------------------------------------------------------------------------
@triton.jit
def flash_attention_kernel(
    Q, K, V,                       # 输入张量指针
    sm_scale,                      # 1/sqrt(d)（标量，运行期）
    Out,                           # 输出张量指针
    # 各张量的 stride（batch, heads, seq, head_dim 四个维度的步长）
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    # 形状
    Z, H, N_CTX,                   # batch, heads, seq_len
    HEAD_DIM: tl.constexpr,        # head dimension（编译期常量）
    BLOCK_M: tl.constexpr,         # Q 方向的 tile 大小
    BLOCK_N: tl.constexpr,         # K/V 方向的 tile 大小
):
    """
    每个 program 计算输出 O 的一个 [BLOCK_M, HEAD_DIM] 切片。

    Grid 维度：
        axis=0  →  Q 方向的 block 索引（共 ceil(N_CTX / BLOCK_M) 个）
        axis=1  →  (batch_idx, head_idx) 平铺，共 Z*H 个

    核心循环：
        Q_block 全程驻留 SRAM
        for K_block, V_block in zip(K_tiles, V_tiles):
            qk = Q_block @ K_block^T
            online softmax 更新 m_i, l_i
            acc = acc * alpha + softmax(qk) @ V_block
        O_block = acc / l_i
    """
    # ---- 1) 定位当前 program ----
    start_m = tl.program_id(0)               # 第几个 Q-block
    off_hz = tl.program_id(1)                # batch*heads 平铺索引
    off_z = off_hz // H                      # batch idx
    off_h = off_hz % H                       # head idx

    # 当前 (batch, head) 的 Q/K/V/O 起始偏移
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh

    # ---- 2) 构造 Q-block 的指针并一次性加载 ----
    # Q-block 形状: [BLOCK_M, HEAD_DIM]，全程驻留寄存器 / SRAM
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # 当前 Q 行索引
    offs_k = tl.arange(0, HEAD_DIM)                     # head_dim 列索引

    q_ptrs = Q + q_offset + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    # mask：Q 的 seq_len 可能不是 BLOCK_M 整数倍，越界行用 0 填充
    q_mask = offs_m[:, None] < N_CTX
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # ---- 3) 初始化 online softmax 累加器 ----
    # m_i: running max,        形状 [BLOCK_M]，初值 -inf
    # l_i: running sum of exp, 形状 [BLOCK_M]，初值 0
    # acc: running 加权 V 累加, 形状 [BLOCK_M, HEAD_DIM]，初值 0
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # ---- 4) 主循环：遍历所有 K/V block ----
    offs_n_base = tl.arange(0, BLOCK_N)  # K/V 方向的局部偏移

    for start_n in range(0, N_CTX, BLOCK_N):
        # 当前 K/V block 的全局 seq 索引
        offs_n = start_n + offs_n_base

        # ---- 4.1) 加载 K block，形状 [BLOCK_N, HEAD_DIM] ----
        k_ptrs = K + k_offset + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        k_mask = offs_n[:, None] < N_CTX
        k = tl.load(k_ptrs, mask=k_mask, other=0.0)

        # ---- 4.2) qk = Q_block @ K_block^T ----
        # Q: [BM, D], K: [BN, D]  →  qk: [BM, BN]
        # tl.dot 会自动转置第二个参数（这里我们手动转 K）
        qk = tl.dot(q, tl.trans(k))
        qk = qk * sm_scale

        # 越界位置（K 的 seq 越界）用 -inf 屏蔽，不影响 max/exp
        qk = tl.where(offs_n[None, :] < N_CTX, qk, -float('inf'))

        # ---- 4.3) Online softmax 更新 ----
        # 新 running max = max(旧 max, 本 block 内每行的 max)
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))

        # alpha = exp(旧 m - 新 m)，用来修正旧累加器
        alpha = tl.exp(m_i - m_ij)

        # 本 block 的归一化 p（已减去新的 max，数值稳定）
        p = tl.exp(qk - m_ij[:, None])

        # 新 running sum = alpha * 旧 sum + 本 block 的 sum
        l_i = l_i * alpha + tl.sum(p, axis=1)

        # ---- 4.4) 加载 V block 并更新加权累加 ----
        # V: [BN, D]  →  p @ V: [BM, D]
        v_ptrs = V + v_offset + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
        v_mask = offs_n[:, None] < N_CTX
        v = tl.load(v_ptrs, mask=v_mask, other=0.0)

        # 修正旧 acc 后加上新 block 的贡献
        # 关键：p 此时已用 m_ij 归一化，与 acc 的归一化基准一致
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)

        # 更新 running max
        m_i = m_ij

    # ---- 5) 收尾：除以 running sum 得到真正的 softmax 加权和 ----
    acc = acc / l_i[:, None]

    # ---- 6) 写回 O ----
    o_ptrs = Out + o_offset + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)
    o_mask = offs_m[:, None] < N_CTX
    tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_mask)


# -----------------------------------------------------------------------------
# 3. Python wrapper
# -----------------------------------------------------------------------------
def flash_attention(q, k, v, sm_scale=None):
    """
    FlashAttention forward（非 causal、单精度路径）。

    输入: q, k, v 形状均为 [BATCH, HEADS, N_CTX, HEAD_DIM]
    输出: o 形状 [BATCH, HEADS, N_CTX, HEAD_DIM]
    """
    assert q.is_cuda and k.is_cuda and v.is_cuda
    assert q.shape == k.shape == v.shape, "Q, K, V 形状必须一致"
    assert q.dtype == k.dtype == v.dtype

    BATCH, HEADS, N_CTX, HEAD_DIM = q.shape

    # 默认 sm_scale = 1/sqrt(d_k)，这是 attention 标准做法
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(HEAD_DIM)

    # HEAD_DIM 必须是 16 / 32 / 64 / 128 / 256 之一（Tensor Core 约束）
    assert HEAD_DIM in (16, 32, 64, 128, 256), \
        f"HEAD_DIM 必须是 2 的幂且 ≤ 256，实际 {HEAD_DIM}"

    o = torch.empty_like(q)

    # tile 大小：BLOCK_M = Q 方向，BLOCK_N = K/V 方向
    # 在 A100/H100 上 64/64 是不错的起点；生产代码应该用 autotune
    BLOCK_M = 64
    BLOCK_N = 64

    # Grid: (Q-block 数, batch*heads)
    grid = (triton.cdiv(N_CTX, BLOCK_M), BATCH * HEADS, 1)

    flash_attention_kernel[grid](
        q, k, v,
        sm_scale,
        o,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        BATCH, HEADS, N_CTX,
        HEAD_DIM=HEAD_DIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=4,
        num_stages=2,
    )
    return o


# -----------------------------------------------------------------------------
# 4. 正确性验证
# -----------------------------------------------------------------------------
def test_correctness():
    torch.manual_seed(0)

    BATCH, HEADS, N_CTX, HEAD_DIM = 2, 4, 256, 64
    q = torch.randn(BATCH, HEADS, N_CTX, HEAD_DIM,
                    device=DEVICE, dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)

    o_triton = flash_attention(q, k, v, sm_scale)
    o_torch = naive_attention(q, k, v, sm_scale)

    max_diff = torch.max(torch.abs(o_triton.float() - o_torch.float())).item()
    rel_diff = max_diff / torch.max(torch.abs(o_torch)).item()

    print(f"[正确性] shape={tuple(q.shape)} dtype={q.dtype}")
    print(f"  最大绝对误差: {max_diff:.3e}")
    print(f"  最大相对误差: {rel_diff:.3e}")

    # FlashAttention 在 fp16 下与朴素实现通常有 1e-2 量级误差（exp 累加）
    assert torch.allclose(o_triton, o_torch, atol=1e-2, rtol=1e-2), \
        "FlashAttention 与朴素 attention 差异过大"
    print("  [PASS]\n")

    # 再测一个 head_dim=128 的"现代 LLM"配置
    BATCH, HEADS, N_CTX, HEAD_DIM = 1, 8, 512, 128
    q = torch.randn(BATCH, HEADS, N_CTX, HEAD_DIM,
                    device=DEVICE, dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    o_triton = flash_attention(q, k, v)
    o_torch = naive_attention(q, k, v, 1.0 / math.sqrt(HEAD_DIM))
    diff = torch.max(torch.abs(o_triton.float() - o_torch.float())).item()
    print(f"[正确性] shape={tuple(q.shape)} dtype={q.dtype}")
    print(f"  最大绝对误差: {diff:.3e}")
    assert torch.allclose(o_triton, o_torch, atol=1e-2, rtol=1e-2)
    print("  [PASS]\n")


# -----------------------------------------------------------------------------
# 5. 性能基准测试
# -----------------------------------------------------------------------------
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N_CTX'],
        x_vals=[2 ** i for i in range(9, 14)],  # 512 ~ 8192
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton FlashAttention', 'Torch (naive)'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='TFLOPS',
        plot_name='flash-attention-performance',
        args={'BATCH': 4, 'HEADS': 32, 'HEAD_DIM': 64, 'dtype': torch.float16},
    )
)
def benchmark(BATCH, HEADS, N_CTX, HEAD_DIM, provider, dtype):
    q = torch.randn((BATCH, HEADS, N_CTX, HEAD_DIM),
                    dtype=dtype, device=DEVICE)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = 1.0 / math.sqrt(HEAD_DIM)
    quantiles = [0.5, 0.2, 0.8]

    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: naive_attention(q, k, v, sm_scale),
            quantiles=quantiles,
        )
    else:  # triton
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: flash_attention(q, k, v, sm_scale),
            quantiles=quantiles,
        )

    # Attention FLOPs：约 4 * BATCH * HEADS * N_CTX^2 * HEAD_DIM
    # （QK^T: 2BHN²D + softmax 忽略 + PV: 2BHN²D = 4BHN²D）
    flops_per_op = 4 * BATCH * HEADS * N_CTX * N_CTX * HEAD_DIM
    perf = lambda ms: flops_per_op * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


# -----------------------------------------------------------------------------
# 6. 主入口
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 60)
    print("Triton 示例 04：FlashAttention 简化版")
    print("=" * 60)

    test_correctness()

    print("[Benchmark] 跑 TFLOPS 对比 (BATCH=4, HEADS=32, HEAD_DIM=64)")
    print("预期：seq 越长 Triton 优势越明显（朴素实现 O(N²) 显存可能 OOM）")
    print("说明：本示例为教学简化版，未启用 warp_specialize / TMA / FP8 等")
    print("      Hopper 高级特性，性能仅供参考")
    print("-" * 60)
    try:
        benchmark.run(print_data=True, show_plots=False)
    except torch.cuda.OutOfMemoryError as e:
        print(f"[OOM] 朴素实现在长序列上爆显存，这正是 FlashAttention 的优势所在")
        print(f"      原始错误: {e}")
