EXPERIMENT REPORT

实验七:梯度检查点 (Gradient Checkpointing) 的时空权衡

2026-01-01 10 MIN READ

1. 实验 7:基础实验

注:

  • 本实验环境基于 摩尔线程 (Moore Threads) MUSA 架构,但也兼容 CUDA 环境进行对比。
  • Happy Path:标准开启 GC,用计算换空间。
  • Sad Path:错误搭配 CPU Offload 导致的 PCIe 拥堵。

1.1 第一阶段:环境准备与算力选型

目标:在国产算力平台上部署 LLaMA-3-8B(2026年标配中小模型),准备基准测试环境。

  1. 算力选型

    • 推荐MTT S4000 (48GB) x 1。这是我们目前的旗舰单卡,48G 显存对于 8B 模型绰绰有余,方便我们测出 Batch Size 的极限。
    • 替代:若使用 N 卡,请选择 A100/H100 (40GB/80GB) 以进行横向对标。
    • 软件栈:MUSA SDK 3.1 + PyTorch (MUSA版) + DeepSpeed (MUSA适配版)。
  2. 环境确认: 打开终端,确认 MUSA 驱动状态。

    hljs bash
    1
    mthreads-smi  # 类似于 nvidia-smi,确认显存空闲且驱动版本正确
    
  3. 依赖安装: 我们需要最新的 Transformers 库来支持自动化的 GC 注入。

    hljs bash
    1
    2
    3
    pip install transformers accelerate deepspeed
    # 确认安装了 torch_musa
    python -c "import torch; import torch_musa; print(torch_musa.is_available())"
    

1.2 第二阶段:构建实验代码 (The Benchmark)

目标:编写脚本 gc_benchmark.py,对比开启 GC 前后的最大 Batch Size 和吞吐量(Samples/Sec)。

  1. 创建实验脚本

    hljs bash
    1
    touch gc_benchmark.py
    
  2. 编写核心逻辑: 不要手动去 wrap 模型,使用 HF 的封装。重点在于捕获 OOM 并记录数据。

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    import torch
    import torch_musa
    import time
    from transformers import AutoModelForCausalLM, AutoConfig
    from torch.utils.checkpoint import checkpoint
    
    # 模拟 LLaMA-3-8B 配置(使用随机初始化避免下载权重耗时)
    config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
    
    def benchmark(use_gc, batch_size, seq_len=4096):
        print(f"\n--- Testing: GC={use_gc}, BS={batch_size} ---")
        
        # 1. 初始化模型到 MUSA 设备
        model = AutoModelForCausalLM.from_config(config).to("musa")
        model.train()
        
        # 2. 开启 Gradient Checkpointing
        if use_gc:
            # 关键点:MUSA 驱动对部分算子有特殊优化,use_reentrant=False 通常更稳定
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        
        # 3. 构造 Dummy Data
        input_ids = torch.randint(0, 1000, (batch_size, seq_len)).to("musa")
        labels = input_ids.clone()
        
        try:
            # Warmup
            _ = model(input_ids, labels=labels).loss.backward()
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
            optimizer.zero_grad()
            
            # Measurement
            start_time = time.time()
            for _ in range(5): # 跑5个step取平均
                loss = model(input_ids, labels=labels).loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            
            end_time = time.time()
            avg_time = (end_time - start_time) / 5
            print(f"Success! Avg Time per Step: {avg_time:.4f}s")
            
            # 获取当前显存占用 (MUSA API)
            mem_alloc = torch.musa.memory_allocated() / 1024**3
            print(f"Memory Allocated: {mem_alloc:.2f} GB")
            return avg_time
            
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print("OOM Triggered!")
                return None
            else:
                raise e
        finally:
            # 清理显存,防止影响下一轮
            del model, input_ids, labels
            torch.musa.empty_cache()
    
    if __name__ == "__main__":
        # 对比测试:逐步增加 Batch Size
        batch_sizes = [1, 2, 4, 8, 16, 32]
        
        print(">>> 阶段 1: 关闭 GC (Baseline)")
        for bs in batch_sizes:
            res = benchmark(use_gc=False, batch_size=bs)
            if res is None: break # 一旦OOM就停止
            
        print("\n>>> 阶段 2: 开启 GC (Activation Recomputation)")
        for bs in batch_sizes:
            res = benchmark(use_gc=True, batch_size=bs)
            if res is None: break
    

1.3 第三阶段:执行与观测

操作提示:你需要同时关注终端输出和系统级监控。

  1. 开启监控(分屏): 在第二个终端窗口,实时观察 GPU 负载。在 MUSA 架构上,重点看 Gpu UtilMem Util

    hljs bash
    1
    watch -n 1 mthreads-smi
    
  2. 运行实验

    hljs bash
    1
    python gc_benchmark.py
    
  3. 实验预期结果

    • 显存:开启 GC 后,同 Batch Size 下显存占用应下降 50%-60%
    • 速度:开启 GC 后,单步时间应增加 25%-30%
    • 极限:BS=8 时若 Baseline OOM,开启 GC 后应能跑到 BS=16 甚至 BS=24。

1.4 第四阶段:实验结果分析指南

当脚本运行时,请重点观察以下数据:

  1. Throughput Drop (%):计算公式 (Time_GC - Time_Base) / Time_Base
    • 如果 > 35%,说明重计算过程中的算子调度在 MUSA 上存在瓶颈。
  2. Memory Saving (%)
    • 如果 < 40%,检查是否只 Checkpoint 了部分层(Transformers 默认策略通常是全层)。

2. 实验 7:进阶实验

为什么需要进阶实验? 基础实验只展示了理想情况。在国产化适配的深水区,开发者常因为切分粒度过细IO瓶颈导致性能崩塌。

2.1 进阶 1:颗粒度陷阱 (Granularity Hell)

  • 缺失点:基础实验是对整个 Transformer Layer 做 Checkpoint。如果我们要手动优化,对更小的算子(如 Attention 内部)做 Checkpoint 会怎样?
  • 操作:手动插入 checkpoint,切碎计算图。
  • 观察
    • N 厂现象:CUDA Kernel Launch 开销低,细粒度虽有损耗但可控。
    • 国产挑战:MUSA 目前的 Kernel Launch overhead 相对较高。如果你切得太碎,启动 Kernel 的时间可能比重计算的时间还长,导致 GPU 利用率锯齿状波动。

执行步骤

  1. 创建文件 gc_granularity.py 目标:对比 Layer 级 GC 与 细粒度 GC 的性能差异。

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    # 伪代码逻辑,展示核心差异
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint
    
    class FineGrainedBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(4096, 4096).to("musa")
            self.linear2 = nn.Linear(4096, 4096).to("musa")
            self.act = nn.GELU()
    
        def forward(self, x):
            # 错误示范:粒度太细,导致大量 kernel launch 开销
            # MUSA 驱动在频繁切换上下文时会在此处累积延迟
            x = checkpoint(self.linear1, x, use_reentrant=False) 
            x = self.act(x)
            x = checkpoint(self.linear2, x, use_reentrant=False)
            return x
    
    # 运行此模块并测量时间,对比将整个 forward 包裹在一次 checkpoint 中的时间
    # ... (省略 boilerplate 代码)
    
  2. 运行与分析: 运行后你会发现,虽然显存省得更多极致,但训练速度可能下降 50% 以上,得不偿失。


2.2 进阶 2:破坏性测试——IO 瓶颈 (The I/O Wall)

  • 场景:结合 DeepSpeed CPU Offload 使用 GC。
  • 操作:在 ds_config 中开启 cpu_offload 并同时开启 GC。
  • 国产挑战
    • MUSA 卡的 PCIe 带宽利用率通常比 N 卡敏感。
    • GC 需要重新计算激活值 -> 需要输入数据 -> 如果输入数据被 Offload 到了 CPU -> 触发 PCIe 通信
    • 结果:计算单元在等数据(Stall),GPU 利用率暴跌。

3. 实验总结与核心知识点

3.1【核心结论】

GC 是国产显卡的“穷人版扩容卡”。 在显存带宽 (HBM) 紧缺而算力 (FLOPS) 相对充裕的 MUSA 架构上,用 30% 的算力损耗换取 2 倍的 Batch Size,是极具性价比的生存策略。

3.2【技术解剖:MUSA vs CUDA】

  1. N 厂方案:配合 FlashAttention,CUDA 生态已经做到了 Layer 级甚至 Block 级的智能选择,部分场景下 GC 损耗可压至 15% 以内。
  2. 国产化方案:我们必须避免“细粒度”陷阱。由于驱动层面的 Overhead,MUSA 更适合 Coarse-grained (粗粒度) 的 Checkpointing(即以整个 Transformer Layer 为单位)。
  3. 编译器差异:在 MUSA 上,重计算可能导致指令重排。如果发现 Loss 曲线微小抖动,通常是 FMA(融合乘加)指令在两次计算中的执行顺序不同导致的精度扰动,属于正常现象。

3.3【关键概念 (Knowledge Points)】

  • 激活重计算 (Rematerialization):GC 的本质。以时间(FLOPs)换空间(VRAM)。
  • 算术强度 (Arithmetic Intensity):开启 GC 实际上提升了模型的算术强度(计算量增加,访存量相对减少)。这对于显存带宽受限的国产卡其实是利好。
  • OOM 边界 (OOM Boundary):硬件决定下限,GC 决定上限。不懂 GC 的调优工程师,在国产卡上寸步难行。

老鸟锐评: 别只盯着显存占用率看。如果你发现开启 GC 后,训练速度下降超过 40%,别急着怪架构,先去检查你的 DataLoader 是不是因为 CPU 忙着帮 GPU 做数据搬运(Offload)而卡住了——这才是系统工程。