跳到正文

AMP:自动混合精度作为调度策略

AMP 最容易犯的错误是认为它意味着“以半精度训练整个模型”。

这不是正确的思维模式。 AMP 是一种运行时精度策略。它允许 PyTorch 为从中受益的操作选择较低的精度,同时将敏感操作保持在更安全的精度。

目标很实际:

more throughput
less activation memory
minimal manual dtype surgery
acceptable numerical stability

两个活动部件

在 PyTorch 中,AMP 主要是两种机制:

autocast: choose execution dtype per operation
GradScaler: protect fp16 gradients from underflow and overflow

典型的 bf16 训练步骤如下所示:

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
 logits = model(x)
 loss = F.cross_entropy(logits, y)

loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

重要的部分是边界。 autocast 包装了前向计算。 Backward 通常不需要自己的 autocast 块,因为它遵循前向图中记录的 dtype 决策。

Autocast 是调度员的决定

autocast 内部,每个 PyTorch 操作都会经历一个策略决策。有些操作在较低精度下是安全且有利可图的。其他的则对数字敏感。

操作家族典型的自动施放行为原因
matmullinearconvbf16 或 fp16张量核心可以使这些速度更快
注意力矩阵乘法bf16 或 fp16高算术强度
Softmax、范数、缩减fp32 或内部 fp32数字敏感
损失函数常 fp32 路径保护损失稳定性
逐元素运算通常遵循输入较低的绩效杠杆

所以 AMP 不是全局转换。这是每操作执行策略。

参数通常保留在原来的位置

如果模型权重以 fp32 开头,则 autocast 不会永久重写它们。

model.weight.dtype # torch.float32

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
 y = model(x)

model.weight.dtype # still torch.float32

linear 调用期间,PyTorch 可能会使用较低精度的临时输入或内核。参数对象本身仍然是 fp32。优化器状态通常也保留为 fp32。

这就是为什么 AMP 减少的激活和临时缓冲区成本比减少整个训练状态占用空间要多。

bf16 和 fp16 解决不同的痛点

主要区别在于动态范围。

数据类型动态范围精密训练行为
FP32最稳定
FP16中等可以下溢或溢出
BF16接近 fp32较粗通常对于大型模型来说更容易

bf16 保留了 fp32 指数宽度,因此它的动态范围比 fp16 大得多。这就是为什么 bf16 训练通常不需要GradScaler

fp16 则不同。小梯度可能下溢到零:

small gradient -> underflow -> 0

GradScaler 的工作原理是在向后缩放之前缩放损失,然后在优化器步骤之前取消缩放梯度:

scaler = torch.cuda.amp.GradScaler()

with torch.autocast(device_type="cuda", dtype=torch.float16):
 logits = model(x)
 loss = F.cross_entropy(logits, y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

从概念上讲:

scale loss
 -> backward produces scaled gradients
 -> unscale before step
 -> check inf or nan
 -> step if safe, skip and lower scale if unsafe

梯度裁剪有一个陷阱

对于带有缩放器的 fp16,取消缩放后进行剪辑:

scaler.scale(loss).backward()
scaler.unscale_(optimizer)

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

如果直接剪切缩放渐变,则剪切阈值不再具有您认为的含义。

心智模型

有用的总结是:

master parameters: usually fp32
large matmuls: temporary low precision
sensitive ops: fp32 or internal fp32
fp16: use GradScaler
bf16: usually no GradScaler
backward: follows the forward graph

AMP 是调度层精准策略。它与在模型上手动调用 .half() 不同,将这两者视为等效是混淆调试数值问题的最快方法。