EXPERIMENT REPORT

实验五:训练显存解构与梯度累积 (Gradient Accumulation)

2025-12-28 10 MIN READ

1. 实验五:基础实验

注:

  • 本实验环境假设为 RTX 4090 (24GB)MTT S4000 (48GB)
  • 现在的年份是 2026 年,PyTorch 生态已相对成熟,但在国产卡上“版本对齐”依旧是基本功。
  • Happy Path:通过技巧绕过显存墙。
  • Sad Path:朴素训练,开局即炸。

1.1 第一阶段:实例租赁与环境准备

目标:准备全量微调(Full Fine-Tuning)环境。

  1. 租赁实例

    • 算力选型:在 AutoDL 或类似平台,选择单卡 24GB+ 显存实例。
    • 镜像选择:PyTorch 2.4+ / Python 3.10 / CUDA 12.x(如果是 N 卡)。如果是国产卡,请加载对应的 MUSA Toolkit 专用镜像。
  2. 安装核心依赖: 全量微调不同于推理,我们需要 deepspeed (作为对照组或优化手段) 和 accelerate

    hljs bash
    1
    pip install transformers accelerate deepspeed bitsandbytes modelscope
    

1.2 第二阶段:模型下载

目标:准备 Qwen2.5-1.5B。选小模型是因为我们要解剖显存,大模型在单卡上不通过 Offload 根本跑不起来,没法分析“显存构成”,只能分析“报错日志”。

  1. 创建下载脚本 download_model.py

    hljs python
    1
    2
    3
    4
    from modelscope import snapshot_download
    # 1.5B 参数量适中,适合做显存解剖手术
    model_dir = snapshot_download('qwen/Qwen2.5-1.5B', cache_dir='/root/autodl-tmp')
    print(f"Model downloaded to: {model_dir}")
    
  2. 执行下载

    hljs bash
    1
    python download_model.py
    

1.3 第三阶段:构建实验代码 (The Experiment)

目标:编写脚本 train_vram_analysis.py。通过对比“朴素 AdamW”和“梯度累积”两种模式,量化显存开销。

  1. 编写代码

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    import torch
    from transformers import AutoModelForCausalLM, AutoConfig
    import time
    
    # 模拟配置
    MODEL_PATH = "/root/autodl-tmp/qwen/Qwen2.5-1.5B" 
    
    def print_memory_stats(stage):
        torch.cuda.synchronize()
        # 显存单位转为 GB
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"[{stage}] Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")
    
    def run_experiment(batch_size, grad_accum_steps, use_checkpointing=False):
        print(f"\n=== Experiment: BS={batch_size}, GA={grad_accum_steps}, CKPT={use_checkpointing} ===")
        
        # 1. 加载模型 (BF16)
        # 老鸟锐评:BF16 是训练标配,FP16 容易溢出,FP32 显存翻倍。
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH, 
            torch_dtype=torch.bfloat16, 
            device_map="cuda",
            attn_implementation="flash_attention_2" # N卡必开,国产卡看驱动支持情况
        )
        
        if use_checkpointing:
            model.gradient_checkpointing_enable()
            
        print_memory_stats("Model Loaded (Static)")
        
        # 2. 优化器 (AdamW)
        # 关键点:AdamW 需要维护 Momentum 和 Variance,都是 FP32,这是显存大户
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        
        # 3. 模拟输入 (Seq Len = 512)
        input_ids = torch.randint(0, 1000, (batch_size, 512)).cuda()
        
        try:
            # 4. Forward & Backward Loop
            model.train()
            
            # 记录开始前的显存快照
            torch.cuda.memory._record_memory_history(max_entries=100000)
            
            optimizer.zero_grad()
            
            # 模拟梯度累积
            for i in range(grad_accum_steps):
                outputs = model(input_ids, labels=input_ids)
                loss = outputs.loss / grad_accum_steps
                loss.backward()
                print_memory_stats(f"Step {i+1} Backward Done")
                
            optimizer.step()
            print_memory_stats("Optimizer Step Done")
            
            # 导出显存快照用于分析
            torch.cuda.memory._dump_snapshot(f"snapshot_bs{batch_size}_ga{grad_accum_steps}.pickle")
            print(">>> SUCCESS: Training step completed.")
            
        except RuntimeError as e:
            print(f">>> OOM TRIGGERED: {str(e)}")
        finally:
            del model, optimizer, input_ids
            torch.cuda.empty_cache()
    
    if __name__ == "__main__":
        # Sad Path: 强行大 Batch,试图直接炸显存
        # 预期:在 Backward 阶段激活值爆炸导致 OOM
        run_experiment(batch_size=4, grad_accum_steps=1, use_checkpointing=False)
        
        # Happy Path: 梯度累积 + 小 Batch
        # 预期:显存平稳,通过时间换空间
        run_experiment(batch_size=1, grad_accum_steps=4, use_checkpointing=False)
    

1.4 第四阶段:执行与观测

  1. 开启监控: 在第二个终端输入:

    hljs bash
    1
    2
    3
    watch -n 0.5 nvidia-smi 
    # 或者如果是摩尔线程卡:
    # watch -n 0.5 musa-smi
    
  2. 运行实验

    hljs bash
    1
    python train_vram_analysis.py
    
  3. 结果示例

    hljs output
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    === Experiment: BS=4, GA=1, CKPT=False ===
    [Model Loaded (Static)] Allocated: 3.05 GB | Reserved: 3.20 GB
    ...
    >>> OOM TRIGGERED: CUDA out of memory. Tried to allocate 4.50 GiB ...
    
    === Experiment: BS=1, GA=4, CKPT=False ===
    [Model Loaded (Static)] Allocated: 3.05 GB | Reserved: 3.20 GB
    [Step 1 Backward Done] Allocated: 7.80 GB | Reserved: 9.10 GB
    ...
    >>> SUCCESS: Training step completed.
    

1.5 第五阶段:实验结果分析指南

数据解剖

  1. 静态显存 (Static Memory)

    • 模型权重 (1.5B * 2B = 3GB) + 显存碎片预留。这是“房租”,无论你跑不跑都在那。
    • 老鸟锐评:如果你的空载显存占用异常高,检查一下是不是 PyTorch Context 初始化占了 1GB,或者显卡驱动有内存泄漏。
  2. 动态显存 (Dynamic Memory)

    • Activation (激活值):Backward 阶段的峰值。BS=4 时,激活值可能瞬间达到 10GB+,直接撑爆显存。
    • Optimizer States:AdamW 的状态需要在 step() 时具体化,瞬间占用 $1.5B \times 12B \approx 18GB$ (如果是 FP32)。
    • 结论:这就是为什么全量微调 1.5B 模型在 24GB 卡上如果不做优化(Offload/LoRA)几乎不可能。

2. 实验五:进阶实验

为什么需要进阶实验? 简单的 Batch Size 调整只是入门。真正的显存杀手是 Sequence Length ($O(N^2)$)显存碎片 (Fragmentation)。这是区分“能跑 demo”和“能落地生产”的分水岭。

2.1 进阶 1:序列长度爆炸与显存碎片

  • 缺失点:基础实验固定了 SeqLen=512。实际业务中,RAG 或长文本往往需要 4k-32k。
  • 操作:固定 Batch Size = 1,指数级增加 Sequence Length,直到 OOM。
  • 观察
    • N 厂现象:FlashAttention 使得显存增长接近线性,显存利用率极高,碎片少。
    • 国产挑战:如果没有完美的算子融合(Operator Fusion),大量中间 Tensor 的申请与释放会导致严重的显存碎片。明明显示还有 4GB 空闲,却分配不出连续的 1GB 空间。

执行步骤

  1. 创建文件 exp5_seq_torture.py

    目标:测试长序列下的 Activation 增长与碎片化极限。

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    import torch
    from transformers import AutoModelForCausalLM
    
    def torture_test():
        model = AutoModelForCausalLM.from_pretrained(
            "/root/autodl-tmp/qwen/Qwen2.5-1.5B", 
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            attn_implementation="eager" # 关闭 FlashAttn,模拟最坏情况或国产卡早期驱动
        )
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        
        seq_lengths = [512, 1024, 2048, 4096, 8192]
        
        for seq_len in seq_lengths:
            print(f"\n>>> Testing Sequence Length: {seq_len}")
            try:
                torch.cuda.empty_cache()
                input_ids = torch.randint(0, 1000, (1, seq_len)).cuda()
                
                # Forward
                outputs = model(input_ids, labels=input_ids)
                
                # Backward (这里是激活值显存峰值)
                outputs.loss.backward()
                
                print(f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
                print(f"Reserved:  {torch.cuda.memory_reserved()/1024**3:.2f} GB")
                
                optimizer.zero_grad() # 清空梯度,准备下一轮
                
            except RuntimeError as e:
                print(f"!!! OOM at SeqLen {seq_len}: {str(e)}")
                break
                
    if __name__ == "__main__":
        torture_test()
    
  2. 运行文件

    hljs bash
    1
    python exp5_seq_torture.py
    
  3. 观测重点: 注意观察 ReservedAllocated 的差值。如果差值越来越大,说明碎片化严重。这是 MUSA 架构早期驱动重点优化的方向——Cuda 的 Caching Allocator 非常成熟,能有效复用碎片。


3. 实验总结与核心知识点

3.1【核心结论】

硬件决定下限,软件决定上限。通过梯度累积 (Gradient Accumulation) 用时间换空间,我们可以绕过显存墙。但在全量微调场景下,优化器状态 (Optimizer States) 才是真正的 Boss,必须配合 8-bit Optimizer 或 Offload 技术才能在消费级卡上通关。

3.2【技术解剖:显存三态】

  1. 静态显存 (Static):模型参数。雷打不动,开机即占。
  2. 动态显存 (Activations):随 Batch Size 和 Sequence Length 剧烈波动。FlashAttention 是这一领域的救世主。
  3. 临时显存 (Buffers/Workspace):PyTorch 和底层 Kernel 运行时的临时空间。老鸟锐评:国产卡由于算子融合度不够,这部分开销往往比 N 卡大,导致同样的 24GB,在 N 卡能跑,在国产卡上 OOM。

3.3【关键概念 (Knowledge Points)】

  • Gradient Accumulation:分期付款。一次算不起,分 16 次算,最后一次性结账(更新权重)。
  • Activation Checkpointing (重计算):用计算换显存。不存中间结果,反向传播用到时再算一遍。这是在 24GB 卡上跑大模型的必备技能。
  • Fragmentation (碎片化):显存里的“公摊面积”。看起来还有空地,但太碎了停不下大车。这是驱动层优化的核心战场。