Skip to main content

AMP:Automatic Mixed Precision 技术笔记

AMP 的目标是:在不手动改模型 dtype 的情况下,让训练自动混合使用高精度与低精度,从而提升速度、降低显存,同时尽量保持数值稳定。

它主要由两部分组成:

autocast:自动决定每个 op 用什么 dtype
GradScaler:主要为 fp16 防止梯度 underflow / overflow

1. autocast 做了什么

典型写法:

with torch.autocast(devicetype="cuda", dtype=torch.bfloat16):
    logits = model(x)
    loss = F.crossentropy(logits, y)

loss.backward()
optimizer.step()
optimizer.zerograd(settonone=True)

进入 autocast 后,PyTorch 会启用一个 autocast 状态。之后每个 PyTorch op 会经过 dispatcher,dispatcher 根据该 op 的 policy 决定执行 dtype。

简化理解:

op 类型常见 autocast 行为原因
matmul / linear / conv用 bf16/fp16Tensor Cores 加速明显
attention 中的大矩阵乘法用 bf16/fp16计算量大,收益高
softmax / norm / reduction常用 fp32 或内部 fp32数值敏感
loss,如 cross entropy常保留 fp32 路径避免 loss 不稳定
普通 elementwise多数跟随输入 dtype计算成本较低

重点:autocast 不是把整个模型改成低精度,而是按 op 自动选择 dtype。

2. autocast 不会永久改变参数 dtype

假设模型参数是 fp32:

model.weight.dtype  # torch.float32

在 autocast 中:

with torch.autocast("cuda", dtype=torch.bfloat16):
    y = model(x)

PyTorch 可能临时把 linear 的输入和权重 cast 到 bf16 执行,但参数本体仍是 fp32。

model.weight.dtype  # 仍然是 torch.float32

这意味着:

  • optimizer 更新的通常还是 fp32 参数。

  • optimizer states 通常也仍是 fp32。

  • AMP 主要节省 activations 和临时计算 buffer,不一定让全部训练状态减半。

3. bf16 vs fp16

dtype动态范围精度训练稳定性
fp32最稳
fp16容易 underflow / overflow
bf16接近 fp32较粗通常比 fp16 稳

bf16 的关键优势:exponent 位数和 fp32 一样多,所以动态范围大。因此大模型训练中,bf16 通常比 fp16 更省心,也通常不需要 GradScaler

4. GradScaler 为什么主要用于 fp16

fp16 动态范围小,小梯度可能变成 0:

small grad -> underflow -> 0

GradScaler 的做法是先放大 loss:

scaledloss = loss  scale

于是 backward 得到的梯度也被放大:

scaledgrad = truegrad × scale

optimizer step 前再除回来,并检查是否出现 inf / nan

典型 fp16 写法:

scaler = torch.cuda.amp.GradScaler()

with torch.autocast(devicetype="cuda", dtype=torch.float16):
    logits = model(x)
    loss = F.crossentropy(logits, y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zerograd(settonone=True)

内部逻辑:

1. scale(loss)
2. backward 得到 scaled gradients
3. step 前 unscale gradients
4. 检查 inf / nan
5. 如果正常:optimizer.step()
6. 如果异常:跳过 step,降低 scale

5. backward 一般不包进 autocast

推荐:

with torch.autocast("cuda", dtype=torch.bfloat16):
    loss = computeloss(model, x, y)

loss.backward()

不推荐特意写成:

with torch.autocast("cuda", dtype=torch.bfloat16):
    loss = computeloss(model, x, y)
    loss.backward()

原因:backward 的 dtype 通常由 forward graph 中记录的 op 决定。forward 用什么路径,backward 会沿着对应 graph 执行,不需要额外包 autocast。

6. 常见注意点

场景正确做法
bf16 训练通常只用 autocast(dtype=torch.bfloat16)
fp16 训练autocast(dtype=torch.float16) + GradScaler
gradient clippingfp16 下先 scaler.unscale_(optimizer),再 clip
debug 数值问题打印 tensor dtype、检查 nan/inf
评估/inference可以用 autocast,但不需要 GradScaler
手动 .half()不等价于 AMP,风险更高

gradient clipping 示例:

scaler.scale(loss).backward()
scaler.unscale(optimizer)

torch.nn.utils.clipgradnorm(model.parameters(), maxnorm=1.0)

scaler.step(optimizer)
scaler.update()

7. 最重要的 mental model

参数主副本:通常 fp32
大矩阵乘法:临时低精度,用 Tensor Cores
数值敏感 op:保留 fp32 或内部 fp32
fp16:需要 GradScaler
bf16:通常不需要 GradScaler

一句话总结:

AMP 的本质是在 PyTorch dispatcher 层面,根据 op 的性能收益和数值稳定性,自动选择执行精度;它不是简单地把模型整体转成半精度。