EXPERIMENT REPORT

实验六:混合精度 (AMP) 的稳定性与 Loss Scaling

2025-12-29 10 MIN READ

1. 实验六:基础实验

老鸟注

  • 本实验不仅是跑通代码,更是为了验证显卡的 FPU(浮点单元)是否在“偷懒”。
  • Happy Path:BF16 直通,无需 Scaling,丝般顺滑。
  • Sad Path:FP16 不加 Scaler,梯度下溢(Underflow)导致模型无法收敛,或者梯度爆炸变成 NaN。
  • 环境假设:年份 2026,假设你正在使用摩尔线程 S4000/S5000 系列或 NVIDIA 对标卡。

1.1 第一阶段:环境与 MUSA 适配检查

目标:确保驱动层正确加载,且 PyTorch 能识别到 MUSA 设备(或 CUDA 设备)。

  1. 环境自检: 在 2026 年,我们的 torch_musa 已经高度兼容原生 PyTorch,但底层指令集映射依然是关键。

    打开终端,输入:

    hljs bash
    1
    python -c "import torch; print(f'Device: {torch.device(\'musa\' if torch.musa.is_available() else \'cuda\')}, BF16 Support: {torch.cuda.is_bf16_supported() if torch.cuda.is_available() else \'Unknown\'}')"
    

    如果输出 Device 为 cpu,说明驱动没挂上,别往下跑了,去修驱动。

  2. 安装依赖

    hljs bash
    1
    2
    pip install torch torchvision
    # 如果是摩尔线程环境,请确保 source 了 musa 的 sdk 环境变量
    

1.2 第二阶段:构建“梯度杀手”模型

目标:为了测试稳定性,我们需要一个容易产生极大或极小梯度的“脆弱”模型。不用下载 BERT,我们手搓一个简易的 Transformer Block。

  1. 创建实验脚本

    hljs bash
    1
    touch amp_stability_test.py
    
  2. 编写核心逻辑: 这个脚本会对比三种情况:

    1. BF16 (Happy Path):不带 Scaler,应该正常。
    2. FP16 (Sad Path):不带 Scaler,预期梯度变 0 或 Loss 变 NaN。
    3. FP16 + Scaler (Recovery):带 Scaler,预期恢复正常。
    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    import torch
    import torch.nn as nn
    import math
    
    # 自动适配 MUSA 或 CUDA
    try:
        import torch_musa
        device = torch.device("musa")
    except ImportError:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(f"Running on: {device}")
    
    # 一个容易梯度爆炸/消失的简易 Transformer 层
    class FragileModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(4096, 4096)
            self.activation = nn.GELU() # 涉及 exp 计算,FP16 敏感区
            self.linear2 = nn.Linear(4096, 1024)
        
        def forward(self, x):
            # 故意不做 LayerNorm,增加数值不稳定性
            return self.linear2(self.activation(self.linear1(x)))
    
    def run_test(mode_name, use_bf16, use_scaler):
        print(f"\n--- Testing Mode: {mode_name} ---")
        model = FragileModel().to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # SGD 对梯度更敏感
        scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)
        
        # 模拟输入:方差较大的数据
        data = torch.randn(64, 4096, device=device) * 10.0 
        target = torch.randn(64, 1024, device=device)
    
        dtype = torch.bfloat16 if use_bf16 else torch.float16
        
        for step in range(5):
            optimizer.zero_grad()
            
            with torch.autocast(device_type=device.type, dtype=dtype):
                output = model(data)
                # 使用 MSE Loss,如果数值过大容易溢出
                loss = nn.MSELoss()(output, target)
            
            # 捕获异常
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Step {step}: Loss is NaN/Inf! (FAILURE)")
                break
    
            if use_scaler:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                print(f"Step {step}: Loss={loss.item():.4f}, Scale={scaler.get_scale()}")
            else:
                loss.backward()
                # 检查梯度范数
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                if torch.isnan(grad_norm) or grad_norm == 0:
                    print(f"Step {step}: Gradient anomaly detected (Norm={grad_norm.item()})")
                optimizer.step()
                print(f"Step {step}: Loss={loss.item():.4f}")
    
    if __name__ == "__main__":
        # 1. BF16 Happy Path
        run_test("BF16 (No Scaler)", use_bf16=True, use_scaler=False)
        
        # 2. FP16 Sad Path (容易挂)
        run_test("FP16 (No Scaler)", use_bf16=False, use_scaler=False)
        
        # 3. FP16 Recovery Path (应该能跑)
        run_test("FP16 (With Scaler)", use_bf16=False, use_scaler=True)
    

1.3 第三阶段:执行与观测

  1. 运行实验

    hljs bash
    1
    python amp_stability_test.py
    
  2. 预期输出示例

    hljs output
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    Running on: musa:0
    
    --- Testing Mode: BF16 (No Scaler) ---
    Step 0: Loss=104.2312
    Step 1: Loss=98.1234
    ... (正常下降)
    
    --- Testing Mode: FP16 (No Scaler) ---
    Step 0: Gradient anomaly detected (Norm=nan) 
    Step 0: Loss=nan (FAILURE)
    (或者梯度直接为0,Loss不动)
    
    --- Testing Mode: FP16 (With Scaler) ---
    Step 0: Loss=104.2312, Scale=65536.0
    Step 1: Loss=99.4321, Scale=32768.0 (如果检测到 NaN,Scale 会减半)
    

1.4 第四阶段:实验结果分析指南

重点观察对象:

  1. FP16 无 Scaler 时的死法:是直接 NaN (溢出),还是梯度为 0 (下溢)?这反映了你的数据分布和模型结构触碰了 FP16 的哪根红线。
  2. Scaler 的跳变:如果 Scale 因子频繁减半(backoff),说明硬件通过不了当前精度的计算。在国产卡上,如果 Scale 因子降到了 1.0 依然 NaN,那可能不是精度问题,而是驱动里的算子实现有 Bug。

2. 实验六:进阶实验

“能跑”和“跑得快”是两码事。很多国产卡的 AMP 也就是看着热闹,实际上底层全是 FP32 在回退。

2.1 进阶 1:算子回退与“假 AMP”检测

  • 缺失点:基础实验看不出速度优势。
  • 痛点:在 MUSA 架构早期,我们常遇到 PyTorch 前端分发了 FP16 指令,但后端没有对应的 Kernel,导致驱动隐式地把数据 Cast 回 FP32 计算再 Cast 回来。这比直接跑 FP32 还慢。
  • 操作:通过矩阵乘法(GEMM)压力测试,对比 FP32 与 AMP 的吞吐量。

执行步骤

  1. 创建文件 amp_throughput_test.py

    目标:简单粗暴地测试大矩阵乘法的耗时。如果 AMP 提速不明显(<1.5x),说明有猫腻。

    hljs python
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    import torch
    import time
    
    try:
        import torch_musa
        device = torch.device("musa")
    except ImportError:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    def benchmark(dtype, size=8192, use_autocast=False):
        a = torch.randn(size, size, device=device, dtype=torch.float32)
        b = torch.randn(size, size, device=device, dtype=torch.float32)
        
        # 预热
        for _ in range(5):
            _ = torch.mm(a, b)
        torch.cuda.synchronize() if device.type == 'cuda' else None # MUSA 也需要 sync
    
        start = time.time()
        steps = 20
        
        if use_autocast:
            # 模拟 AMP 环境
            with torch.autocast(device_type=device.type, dtype=dtype):
                for _ in range(steps):
                    c = torch.mm(a, b)
        else:
            # 强制转换类型测试纯算力
            a = a.to(dtype)
            b = b.to(dtype)
            for _ in range(steps):
                c = torch.mm(a, b)
                
        # 摩尔线程 MUSA 同样需要同步流来确保计时准确
        if torch.cuda.is_available(): torch.cuda.synchronize()
        elif hasattr(torch, 'musa'): torch.musa.synchronize()
        
        end = time.time()
        avg_time = (end - start) / steps
        tflops = (2 * size**3) / (avg_time * 1e12)
        return avg_time, tflops
    
    print(f"Benchmarking on {device}...")
    
    # 1. Baseline FP32
    t32, tflops32 = benchmark(torch.float32)
    print(f"FP32: {t32*1000:.2f} ms/step | {tflops32:.2f} TFLOPS")
    
    # 2. FP16 (Pure) - 只有硬件真支持 FP16 指令才快
    t16, tflops16 = benchmark(torch.float16)
    print(f"FP16: {t16*1000:.2f} ms/step | {tflops16:.2f} TFLOPS")
    
    # 3. Ratio
    print(f"Speedup Ratio (FP16/FP32): {t32/t16:.2f}x")
    
  2. 运行与观察

    hljs bash
    1
    python amp_throughput_test.py
    
  3. 结果锐评

    • N 厂现象:在 RTX 4090 上,FP16 速度通常是 FP32 的 2倍以上(Tensor Core 加持)。
    • 国产挑战:如果在你的国产卡上 Speedup Ratio 只有 1.0x - 1.2x,说明两个问题:
      1. 显存带宽不仅没省,反而被 Cast 操作吃掉了。
      2. 计算单元的 FP16 流水线没填满,或者干脆就是用 FP32 单元模拟的。

3. 实验总结与核心知识点

3.1【核心结论】

硬件决定下限(指令集是否支持),驱动决定上限(算子融合能力)。BF16 是 AI 训练的“版本答案”,因为它用精度换取了无需 Scaler 的动态范围,极大地降低了软件栈的适配难度。

3.2【技术解剖:MUSA 视角的 AMP】

  1. Tensor Core (N厂) vs. AI Core (MUSA): N 卡的 AMP 之所以快,是因为它在硬件物理层面上就把乘法(FP16)和累加(FP32)融合了。如果我们国产卡的驱动编译器做不到这种“微操”,频繁的数据搬运(Load FP16 -> Convert FP32 -> Compute -> Convert FP16 -> Store)会抵消掉所有优势。

  2. GradScaler 的遮羞布作用: Loss Scaling 本质上是在修补 FP16 指数位不足的缺陷。对于新架构开发来说,BF16 的支持优先级远高于 FP16,因为 BF16 能让开发者少写很多处理 NaN 的 try-except 代码,这对生态迁移至关重要。

3.3【关键概念 (Knowledge Points)】

  • Dynamic Range (动态范围):指数值能表示的最大值和最小值的区间。FP16 的范围太窄,容易被深度学习中常见的 $e^-10$ 这种小梯度击穿。
  • Underflow (下溢):梯度数值小于浮点数能表示的最小正数,直接变成 0。这是训练“假死”的元凶。
  • Mixed Precision (混合精度):不是全盘 FP16,而是“计算密集型用 FP16(卷积、MatMul),数值敏感型用 FP32(Softmax, Sum)”。PyTorch 的 autocast 就是在这个黑白名单之间做调度员。