
AMP: automatic mixed precision as a dispatch policy

The easiest mistake with AMP is to think it means “train the whole model in half precision.”

That is not the right mental model. AMP is a runtime precision policy. It lets PyTorch choose lower precision for operations that benefit from it, while keeping sensitive operations in safer precision.

The goal is practical:

more throughput
less activation memory
minimal manual dtype surgery
acceptable numerical stability

The two moving parts

In PyTorch, AMP is mainly two mechanisms:

autocast: choose execution dtype per operation
GradScaler: protect fp16 gradients from underflow and overflow

A typical bf16 training step looks like this:

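import torch
import torch.nn.functional as F

# model, optimizer, x, y are assumed to be defined already.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(x)
    loss = F.cross_entropy(logits, y)

loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)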

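The important part is the boundary. autocast wraps the forward pass, including the loss computation. Backward should run outside the autocast block: each backward op automatically runs in the same dtype that autocast chose for the corresponding forward op, so those decisions are already recorded in the graph.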
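A minimal end-to-end sketch of the bf16 step on a toy model can make the boundary concrete. The two-layer network, the shapes, and the AdamW optimizer are arbitrary assumptions, and the script falls back to CPU autocast (which also supports bf16) when no GPU is present.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumption: use CUDA if present, otherwise CPU autocast, which also supports bf16.
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(16, 32, device=device)
y = torch.randint(0, 10, (16,), device=device)

# Forward pass, including the loss, runs inside autocast.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits = model(x)
    loss = F.cross_entropy(logits, y)

# Backward runs outside autocast; each backward op reuses its forward op's dtype.
loss.backward()

logits_dtype = logits.dtype                                # produced by a Linear: bf16
grad_dtypes = {p.grad.dtype for p in model.parameters()}   # grads match the fp32 params

optimizer.step()
optimizer.zero_grad(set_to_none=True)
print(logits_dtype, grad_dtypes)
```

The logits come out in bf16 because the last op is a Linear, yet every `.grad` is fp32 because the gradient flows back through autocast's cast to the fp32 parameter.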

Autocast is a dispatcher decision

Inside autocast, each PyTorch operation goes through a policy decision. Some operations are safe and profitable in lower precision. Others are numerically sensitive.

| Operation family | Typical autocast behavior | Reason |
|---|---|---|
| matmul, linear, conv | bf16 or fp16 | Tensor Cores can make these much faster |
| Attention matrix multiplies | bf16 or fp16 | high arithmetic intensity |
| Softmax, norms, reductions | fp32, or fp32 internal accumulation | numerically sensitive |
| Loss functions | often an fp32 path | protects loss stability |
| Elementwise ops | usually follow their input dtype | little to gain from an explicit cast |

So AMP is not a global conversion. It is a per-op execution policy.
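One way to see the per-op policy is to probe the output dtype of a few ops inside one autocast region. The helper name `autocast_output_dtypes` and the layer sizes are hypothetical; note that the op lists differ between the CUDA and CPU backends, so the softmax result in particular depends on the device.

```python
import torch
import torch.nn as nn

def autocast_output_dtypes(device: str) -> dict:
    """Hypothetical probe: report the dtype each op produces under bf16 autocast."""
    lin = nn.Linear(8, 8).to(device)
    x = torch.randn(4, 8, device=device)      # fp32 input
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        h = lin(x)                            # on the lower-precision list
        r = torch.relu(h)                     # not on any list: follows its input dtype
        s = torch.softmax(h, dim=-1)          # on the CUDA fp32 list; CPU lists differ
    return {"linear": h.dtype, "relu": r.dtype, "softmax": s.dtype}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device, autocast_output_dtypes(device))
```

The fp32 input never gets converted globally; the Linear chooses bf16, and the ReLU simply inherits whatever it receives.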

Parameters usually stay where they are

If model weights start as fp32, autocast does not permanently rewrite them.

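import torch
import torch.nn as nn
model = nn.Linear(16, 16).cuda()
x = torch.randn(8, 16, device="cuda")
model.weight.dtype  # torch.float32

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = model(x)

model.weight.dtype  # still torch.float32
y.dtype             # torch.bfloat16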

During a linear call inside autocast, PyTorch casts the fp32 weight to a temporary lower-precision copy and runs a lower-precision kernel. The parameter object itself remains fp32. Optimizer states follow the parameter dtype, so they usually remain fp32 too.

This is why AMP mainly reduces the cost of activations and temporary buffers. The persistent training state (fp32 weights, their fp32 gradients, and the optimizer states) stays the same size.
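Some back-of-the-envelope arithmetic shows the asymmetry. This sketch assumes Adam/AdamW with two fp32 moment buffers, all parameters trainable, and that every saved activation comes from a lower-precision op (so the activation saving is an upper bound); it ignores the temporary bf16 weight copies and the allocator overhead. The helper names and the activation count are hypothetical, not measurements.

```python
# Hypothetical helpers for rough memory estimates, in bytes.
FP32_BYTES = 4
BF16_BYTES = 2

def persistent_state_bytes(n_params: int) -> int:
    """fp32 weights + fp32 grads + two fp32 Adam moments; autocast changes none of these."""
    weights = n_params * FP32_BYTES
    grads = n_params * FP32_BYTES              # .grad matches the fp32 parameter dtype
    adam_moments = 2 * n_params * FP32_BYTES   # exp_avg and exp_avg_sq
    return weights + grads + adam_moments

def saved_activation_bytes(n_elements: int, autocast_bf16: bool) -> int:
    """Assumes every saved activation is stored at the dtype its op produced."""
    return n_elements * (BF16_BYTES if autocast_bf16 else FP32_BYTES)

n_params = 1_000_000_000
n_act = 4_000_000_000  # illustrative activation count
print(persistent_state_bytes(n_params) / 1e9, "GB persistent state, with or without AMP")
print(saved_activation_bytes(n_act, False) / 1e9, "GB activations in fp32")
print(saved_activation_bytes(n_act, True) / 1e9, "GB activations under bf16 autocast")
```

Under these assumptions, a 1B-parameter model keeps 16 GB of persistent state either way, while the activation term is at most halved.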

bf16 and fp16 make different trade-offs

The main difference is dynamic range.

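| dtype | Exponent / mantissa bits | Dynamic range | Precision | Training behavior |
|---|---|---|---|---|
| fp32 | 8 / 23 | large | high | most stable |
| fp16 | 5 / 10 | small (max 65504) | medium | can underflow or overflow |
| bf16 | 8 / 7 | close to fp32 | coarser than fp16 | usually easier for large models |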

bf16 keeps the fp32 exponent width, so it has a much larger dynamic range than fp16. That is why bf16 training often does not need a GradScaler.
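The range and precision columns can be checked directly with `torch.finfo`. This short sketch just prints the limits for each dtype.

```python
import torch

# Range and precision limits for the three training dtypes.
for dt in (torch.float32, torch.float16, torch.bfloat16):
    fi = torch.finfo(dt)
    # max: largest finite value; tiny: smallest positive normal; eps: gap from 1.0 to the next value
    print(f"{str(dt):16} max={fi.max:.4g}  tiny={fi.tiny:.4g}  eps={fi.eps:.4g}")
```

fp16 tops out at 65504, while bf16 shares fp32's smallest normal value and reaches roughly 3.39e38. The price is precision: bf16's eps is 2^-7, versus 2^-10 for fp16.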

fp16 is different. With only 5 exponent bits, small gradients may underflow to zero, and values above 65504 overflow to inf:

small gradient -> underflow -> 0
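A single tensor value shows both failure modes and why scaling helps. The value 1e-8 and the scale 2^16 are illustrative choices; 2^16 happens to match GradScaler's default `init_scale`.

```python
import torch

g = torch.tensor(1e-8)                        # a small fp32 gradient value
as_fp16 = g.to(torch.float16)                 # below fp16's smallest subnormal (~6e-8): rounds to 0
as_bf16 = g.to(torch.bfloat16)                # inside bf16's range: survives, with coarse precision

scale = 2.0 ** 16                             # illustrative loss scale
scaled_fp16 = (g * scale).to(torch.float16)   # ~6.6e-4: comfortably representable in fp16
recovered = scaled_fp16.float() / scale       # unscale in fp32

too_big = torch.tensor(70000.0).to(torch.float16)  # above fp16 max (65504): overflows to inf
print(as_fp16.item(), as_bf16.item(), recovered.item(), too_big.item())
```

Multiplying by the scale moves the value into fp16's representable range before the cast, and dividing in fp32 afterwards recovers it. That is the core idea behind scaling the loss.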

GradScaler works by scaling the loss before backward, then unscaling gradients before the optimizer step:

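import torch
import torch.nn.functional as F

# model, optimizer, x, y are assumed to be defined already.
scaler = torch.amp.GradScaler("cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(x)
    loss = F.cross_entropy(logits, y)

scaler.scale(loss).backward()   # backward on the scaled loss
scaler.step(optimizer)          # unscales grads, skips the step if any are inf or nan
scaler.update()                 # adjusts the scale for the next iteration
optimizer.zero_grad(set_to_none=True)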

Conceptually:

scale loss
  -> backward produces scaled gradients
  -> unscale before step
  -> check inf or nan
  -> step if safe, skip and lower scale if unsafe
  -> after a long run of safe steps, raise the scale again
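The conceptual loop above can be modeled in plain Python. `ToyLossScaler` is a hypothetical class for illustration, not GradScaler's implementation: it works on a list of floats, and it merges what the real API splits into `scaler.step(optimizer)` and `scaler.update()`. Its constructor defaults mirror GradScaler's documented defaults.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class ToyLossScaler:
    """Hypothetical model of dynamic loss scaling, not GradScaler's actual code."""
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    good_steps: int = field(default=0, init=False)

    def step_and_update(self, scaled_grads: List[float],
                        apply: Callable[[List[float]], None]) -> bool:
        grads = [g / self.scale for g in scaled_grads]   # unscale before stepping
        if not all(math.isfinite(g) for g in grads):     # check for inf or nan
            self.scale *= self.backoff_factor            # unsafe: skip the step, lower the scale
            self.good_steps = 0
            return False
        apply(grads)                                     # safe: take the optimizer step
        self.good_steps += 1
        if self.good_steps == self.growth_interval:      # a long run of safe steps
            self.scale *= self.growth_factor             # try a larger scale again
            self.good_steps = 0
        return True

s = ToyLossScaler(growth_interval=3)
applied = []
s.step_and_update([float("inf"), 1.0], applied.append)   # skipped; scale halves
print(s.scale, applied)
```

The backoff is immediate while growth is slow, so the scale settles near the largest value that does not overflow.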

Gradient clipping has one trap

For fp16 with a scaler, clip after unscaling:

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients are now back in true units

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)

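If you clip the scaled gradients directly, their norm is inflated by the current scale factor, so max_norm=1.0 is compared against the wrong quantity and clipping shrinks the gradients far more than intended.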
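A toy comparison makes the trap measurable. This sketch simulates a scaled gradient by hand on a single parameter instead of running a real GradScaler; the function name, the scale of 1024, and the gradient values are hypothetical. `p.grad.div_(scale)` stands in for `scaler.unscale_(optimizer)`.

```python
import torch

def compare_clip_orders(true_grad, scale=1024.0, max_norm=1.0):
    """Hypothetical helper: clip before vs. after unscaling a hand-scaled gradient."""
    results = {}
    for order in ("unscale_then_clip", "clip_then_unscale"):
        p = torch.nn.Parameter(torch.zeros_like(true_grad))
        p.grad = true_grad * scale            # what backward leaves behind with a scaled loss
        if order == "unscale_then_clip":
            p.grad.div_(scale)                # stands in for scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_([p], max_norm=max_norm)
        if order == "clip_then_unscale":
            p.grad.div_(scale)                # unscaling happens only after clipping
        results[order] = p.grad.clone()
    return results

res = compare_clip_orders(torch.tensor([0.3, 0.4]))   # true gradient norm is 0.5
print({k: v.norm().item() for k, v in res.items()})
```

With the correct order, a gradient of norm 0.5 is left alone. With the wrong order, clipping sees a norm of 512, rescales it to 1, and unscaling then leaves a gradient about 512 times smaller than it should be.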

The mental model

The useful summary is:

master parameters: usually fp32
large matmuls: temporary low precision
sensitive ops: fp32 or internal fp32
fp16: use GradScaler
bf16: usually no GradScaler
backward: follows the forward graph

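AMP is a dispatch-layer precision policy. It is not the same as manually calling .half() on the model, which permanently converts the parameters (and with them the gradients and optimizer states) to fp16 and feeds every op fp16 inputs. Treating the two as equivalent is the fastest way to get confused while debugging numerical issues.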
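A quick dtype check makes the contrast concrete. This sketch runs on CPU with bf16 autocast so that it needs no GPU, which is an assumption made for portability; the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

amp_model = nn.Linear(4, 4)            # stays fp32; autocast decides per op at run time
half_model = nn.Linear(4, 4).half()    # .half() rewrites the parameters in place to fp16

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = amp_model(torch.randn(2, 4))

print(amp_model.weight.dtype, out.dtype)   # fp32 weights, low-precision output
print(half_model.weight.dtype)             # fp16 weights, permanently
```

The autocast model keeps fp32 master weights and only produces a low-precision output, while the `.half()` model has no fp32 copy left to fall back on.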