EXPERIMENT REPORT

实验十:FlashAttention 性能分析与 Kernel Profiling

2026-01-10 10 MIN READ

1. 实验 10:基础实验

注:

  • 本实验建议在摩尔线程 MTT S4000 或 NVIDIA RTX 4090 上运行,显存带宽越敏感越好。
  • Happy Path:开启 FlashAttention (v2)。
  • Sad Path:关闭 FlashAttention,退回标准 Attention。

1.1 第一阶段:环境准备与工具链检查

目标:确保 PyTorch 版本支持 SDPA (Scaled Dot Product Attention) 且 Profiler 工具可用。

  1. 环境检查: FlashAttention 2 需要较新的 PyTorch (2.1+) 和 CUDA 环境。如果是 MUSA 架构,请确保 musa_toolkits 已加载。

  2. 安装/检查依赖

    hljs bash
    1
    2
    3
    # 检查 PyTorch 版本
    python -c "import torch; print(torch.__version__)"
    # 预期输出:2.1.0 或更高
    

1.2 第二阶段:构建 Profiling 代码 (The Probe)

目标:编写带有预热 (Warmup)显式同步 (Synchronization) 的测试脚本。

  1. 创建主实验脚本

    hljs bash
    1
    touch profile_flash_attn.py
    
  2. 编写代码

    这里不仅仅是跑通,重点在于捕捉真实的 GPU 耗时。如果不做 Warmup 和 Sync,你测出来的都是 CPU 的调度开销。

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    import torch
    import torch.nn.functional as F
    from torch.profiler import profile, record_function, ProfilerActivity
    import os
    
    # 模拟 MUSA/CUDA 设备兼容
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 如果是 MUSA 环境,通常兼容 CUDA 语义,或者使用 torch_musa
    
    def run_attention(q, k, v, use_fa=True):
        # 强制使用或禁用 FlashAttention
        # enable_math=False, enable_mem_efficient=False 迫使 PyTorch 走 C++ fallback (慢) 或 FA (快)
        with torch.backends.cuda.sdp_kernel(enable_flash=use_fa, 
                                            enable_math=not use_fa, 
                                            enable_mem_efficient=not use_fa):
            return F.scaled_dot_product_attention(q, k, v)
    
    # 构造数据 (Batch=32, Heads=16, SeqLen=4096, Dim=128) -> 典型的 LLM 负载
    dtype = torch.float16
    q = torch.randn(32, 16, 4096, 128, device=device, dtype=dtype)
    k = torch.randn(32, 16, 4096, 128, device=device, dtype=dtype)
    v = torch.randn(32, 16, 4096, 128, device=device, dtype=dtype)
    
    # 1. 预热 (Warmup) - 关键步骤:填满 Cache,拉高 GPU 频率
    print("Warmup starting...")
    for _ in range(10):
        run_attention(q, k, v, use_fa=True)
    torch.cuda.synchronize() # 确保预热完成
    
    # 2. Profiling Happy Path (FA On)
    print("Profiling FlashAttention ON...")
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("Attention_FA_On"):
            run_attention(q, k, v, use_fa=True)
            torch.cuda.synchronize() # 显式同步,否则测不准
    prof.export_chrome_trace("trace_fa_on.json")
    
    # 3. Profiling Sad Path (Standard Attention)
    print("Profiling Standard Attention (FA Off)...")
    # 再次预热一下,清洗状态
    run_attention(q, k, v, use_fa=False)
    torch.cuda.synchronize()
    
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("Attention_FA_Off"):
            run_attention(q, k, v, use_fa=False)
            torch.cuda.synchronize()
    prof.export_chrome_trace("trace_fa_off.json")
    
    print("Profiling done. Check JSON files in Chrome://tracing")
    

1.3 第三阶段:执行与观测

  1. 运行脚本

    hljs bash
    1
    python profile_flash_attn.py
    
  2. 可视化分析

    • 打开 Chrome 浏览器,输入 chrome://tracing
    • 加载 trace_fa_off.json (Sad Path)。
    • 加载 trace_fa_on.json (Happy Path)。
  3. 预期形态对比

    • Sad Path (标准 Attention):你会看到 GPU 轨道上是一连串密集的短条:MatMul -> Softmax -> MatMul。这就是显存的“心电图”,中间充满了读写 HBM 的间隙。
    • Happy Path (FA):你会看到一个巨大的、单一的长条(例如 flash_fwd_kernel)。此时 SRAM 是忙碌的,但 HBM 读写反而是平滑的。

1.4 第四阶段:实验结果分析指南

当脚本运行时,请重点观察以下数据,验证“理论背景”

  1. Kernel 数量与融合

    • 标准模式下,Attention 至少由 3-5 个 Kernel 组成。
    • FA 模式下,应该只有 1 个大 Kernel。这就是算子融合 (Kernel Fusion) 的极致。
  2. HBM 带宽利用率

    • 如果你有 Nsight Systems 或摩尔线程的 profiler 工具,你会发现 FA 的 HBM 带宽利用率并不一定时刻 100%,因为瓶颈转移到了计算单元(Tensor Core/MUSA Core)的吞吐上。这说明我们成功翻越了“显存墙”。

2. 实验 10:进阶实验

为什么需要进阶实验?

很多开发者以为调了 F.scaled_dot_product_attention 就万事大吉了。作为架构师,我要告诉你硬件最讨厌什么:不对齐 (Misalignment)填充 (Padding)。这在国产卡适配中尤为致命。

2.1 进阶 1:内存对齐陷阱 (The Alignment Trap)

  • 缺失点:基础实验用了标准的 Head_Dim=128,这是 GPU 最舒服的姿势。
  • 操作:将 Head Dimension 改为 129
  • 观察
    • N 厂现象:虽然也有性能回退,但其编译器 (NVCC/PTX) 极其成熟,会尽量做拆分优化。
    • 国产挑战:MUSA 或其他国产架构的编译器,面对非 2 的幂次方(特别是奇数)维度时,向量化加载 (Vectorized Load) 可能失效,退化为标量读取,带宽利用率直接腰斩。

执行步骤

  1. 创建文件 exp10_alignment.py

    目标:对比 Head_Dim=128 (对齐) 与 Head_Dim=129 (非对齐) 的性能差异。

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    import torch
    import time
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    def benchmark_dim(dim, n_iter=100):
        # 保持总计算量近似,主要测带宽对齐效率
        q = torch.randn(16, 16, 2048, dim, device=device, dtype=torch.float16)
        k = torch.randn(16, 16, 2048, dim, device=device, dtype=torch.float16)
        v = torch.randn(16, 16, 2048, dim, device=device, dtype=torch.float16)
        
        # Warmup
        for _ in range(10):
            torch.nn.functional.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        
        start = time.time()
        for _ in range(n_iter):
            torch.nn.functional.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        end = time.time()
        return (end - start) * 1000 / n_iter # ms
    
    time_128 = benchmark_dim(128)
    time_129 = benchmark_dim(129)
    
    print(f"Head_Dim=128 Latency: {time_128:.3f} ms")
    print(f"Head_Dim=129 Latency: {time_129:.3f} ms")
    print(f"Performance Drop: {((time_129 - time_128) / time_128) * 100:.1f}%")
    
  2. 运行文件

    hljs bash
    1
    python exp10_alignment.py
    
  3. 预期结果分析: 如果性能下降超过 20%,说明该架构的底层算子库对非对齐内存访问缺乏 PaddingCoalescing 优化。这是我们在做 MUSA 算子库时最头疼的“脏活”。


3. 实验总结与核心知识点

3.1【核心结论】

FlashAttention 的本质,是用多余的计算(Recomputation)换取稀缺的带宽(Bandwidth)。 在现代 GPU 上,算力是廉价的,IO 才是昂贵的。

3.2【技术解剖:IO 感知】

  1. N 厂方案:利用 Tensor Core 的极高算力,配合 L2 Cache 和 Shared Memory 的精细管理,将 $O(N^2)$ 的 HBM 访问降级为 $O(N)$。
  2. 国产化挑战:我们的痛点往往不在于理论算力,而在于寄存器压力 (Register Pressure)。FA 的 Kernel 极大,需要大量寄存器存状态,如果寄存器不够,就会发生 Spill (溢出到显存),导致优化失效。这也是为什么你在国产卡上跑 FA 有时并未获得预期倍数的原因。

3.3【关键概念 (Knowledge Points)】

  • IO-Awareness (IO 感知):算法设计必须考虑数据搬运成本,而非仅仅是 FLOPs。
  • Kernel Fusion (算子融合):消灭中间变量的显存读写,是 GPU 性能优化的终极手段。
  • Memory Wall (显存墙):当你看到 Profiler 里 HBM 跑满而 SM 利用率不高时,你就撞墙了。

老鸟锐评: 做完这个实验,如果你能理解为什么把 Head Dim 设为 129 是“犯罪”,那你才算真正入了 GPU 编程的门。造显卡像造车,堆料容易,修路(驱动与编译器优化)难。