The Problem: Pure FP16 Training
If you train the entire model in FP16, gradients underflow:
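FP16 range: about 6×10^-5 (smallest normal value) to 6.5×10^4 (maximum, 65504); subnormals reach down to about 6×10^-8, with reduced precision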
Backward pass gradients: Can be 10^-8 to 10^-6
Result: Gradients drop to zero → weights don't update!
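The underflow can be reproduced without a GPU. The following is a minimal sketch using only the Python standard library; the helper `to_fp16` is a hypothetical name introduced here, and it relies on the `struct` module's `"e"` (half-precision) format to round a value the way FP16 storage would.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

FP16_MIN_SUBNORMAL = 2.0 ** -24   # about 5.96e-8
FP16_MIN_NORMAL = 2.0 ** -14      # about 6.10e-5
FP16_MAX = 65504.0

# Illustrative gradient magnitudes (assumed values, not measurements)
for grad in (1e-3, 1e-6, 1e-8):
    stored = to_fp16(grad)
    kind = "zero" if stored == 0.0 else (
        "subnormal" if stored < FP16_MIN_NORMAL else "normal")
    print(f"{grad:.0e} -> {stored:.4e} ({kind})")
```

A gradient of 1e-8 is stored as exactly 0.0, while 1e-6 survives only as a subnormal with a visible rounding error.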
Solution: Mixed-precision training
Standard Mixed-Precision Approach
Keep FP32 master weights (so small updates are not rounded away), compute in FP16/BF16.
Forward pass:
weights_fp16 = weights_fp32.to(fp16)
activations_fp16 = forward_pass(inputs, weights_fp16)
loss_fp16 = loss_function(activations_fp16, targets)
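Backward pass (naive, without scaling; small gradients underflow):
gradients_fp16 = backward_pass(loss_fp16)
Backward pass with loss scaling (prevents gradient underflow):
scaled_loss = loss_fp16.to(fp32) * 2^15 (scale up; kept in FP32 so the scaled loss cannot exceed FP16's maximum of 65504)
scaled_grads_fp16 = backward_pass(scaled_loss)
Update:
gradients_fp32 = scaled_grads_fp16.to(fp32) / 2^15 (cast to FP32 first, then scale back down)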
weights_fp32 -= lr * gradients_fp32
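To see why the update is applied to an FP32 master copy rather than to the FP16 weights, here is a small sketch under assumed values (learning rate, gradient, and step count are made up for illustration). The helpers `to_fp16` and `to_fp32` are hypothetical names that emulate storage rounding with the standard-library `struct` module.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

lr, grad, steps = 0.01, 0.01, 100   # each update is lr * grad = 1e-4

w_fp16_only = to_fp16(1.0)   # weight kept only in FP16
w_master = to_fp32(1.0)      # FP32 master copy

for _ in range(steps):
    # FP16 values just below 1.0 are spaced about 4.9e-4 apart,
    # so a 1e-4 update rounds straight back to 1.0.
    w_fp16_only = to_fp16(w_fp16_only - lr * grad)
    # FP32 has enough precision to accumulate every small update.
    w_master = to_fp32(w_master - lr * grad)

print(w_fp16_only)          # 1.0: every update was rounded away
print(round(w_master, 4))   # 0.99
```

The FP16-only weight never moves, even though the gradient itself is perfectly representable; the master copy ends up near the expected 0.99.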
Why Loss Scaling Works
| Step | Without Scaling | With Scaling (2^15) |
|---|---|---|
| Gradient magnitude | 1e-8 (underflow!) | 3.3e-4 (safe) |
| FP16 represents it? | NO → 0 | YES |
| After backward | 0 (information lost) | Accurate |
| After scaling back | 0 (nothing to recover) | 1e-8 (correct) |
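The arithmetic behind loss scaling can be checked directly. This is a sketch under stated assumptions: a gradient of 1e-8 is taken as an illustrative value, the scale is 2^15, and multiplying the gradient by the scale stands in for what the chain rule does when the loss is scaled before the backward pass. `to_fp16` is a hypothetical helper built on the standard-library `struct` module.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

SCALE = 2.0 ** 15
true_grad = 1e-8                      # illustrative value

# Without scaling: the gradient is stored in FP16 and flushes to zero.
unscaled = to_fp16(true_grad)

# With scaling: the stored value is 2^15 times larger, inside FP16's normal range.
scaled = to_fp16(true_grad * SCALE)

# Unscale in higher precision (a Python float here, FP32 in practice).
recovered = scaled / SCALE
rel_error = abs(recovered - true_grad) / true_grad

print(unscaled)                       # 0.0
print(f"{scaled:.4e}")
print(f"{recovered:.4e} (relative error {rel_error:.1e})")
```

The division has to happen after the cast to higher precision; dividing in FP16 would push the value straight back below the representable range.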
Hardware Support
NVIDIA Automatic Mixed Precision (AMP)
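Tensor Cores: H100 has dedicated hardware for reduced-precision matrix math, including FP16/BF16 and:
- TF32 (stored in 32 bits, 19 bits used: 1 sign, 8 exponent, 10 mantissa)
- FP8 (8-bit float)
- Automatic dtype casting (done in software by AMP, which picks a precision per operation)
Usage: Wrap the forward pass and loss computation in `torch.cuda.amp.autocast()`; with FP16, pair it with `GradScaler` for loss scaling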
Google TPU Mixed-Precision
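TPU v4 computes matrix multiplications in BF16 and stores weights in FP32:
- BF16 has the same 8-bit exponent as FP32, so it covers the same range and gradients do not underflow as they do in FP16; the cost is fewer mantissa bits (7 vs. 10 in FP16)
- Loss scaling is much less critical and often omitted, though it is sometimes kept for stability
- Compute: up to roughly 2× faster than FP32
- Memory: weight storage is unchanged (master weights are still FP32)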
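The range-versus-precision trade-off between BF16 and FP16 can be sketched in plain Python. The helper `to_bf16` below is a hypothetical emulation: it keeps the top 16 bits of the FP32 encoding with round-to-nearest-even, and it is only meant for ordinary finite inputs (NaN and values near the FP32 maximum are not handled).

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Emulate bfloat16: keep the top 16 bits of the FP32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                    # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

tiny_grad = 1e-8
print(to_fp16(tiny_grad))    # 0.0: below FP16's range
print(to_bf16(tiny_grad))    # nonzero: BF16 shares FP32's exponent range

fine_value = 1.001
print(to_fp16(fine_value))   # 1.0009765625: FP16 keeps 10 mantissa bits
print(to_bf16(fine_value))   # 1.0: BF16 keeps only 7 mantissa bits
```

BF16 preserves the tiny gradient but cannot distinguish 1.001 from 1.0, which is one reason the master weights stay in FP32.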
Which Layers Need Which Precision?
| Layer | Precision | Reason |
|---|---|---|
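| Attention softmax | FP32 | Exponentials and row sums need stability; the attention matmuls themselves can run in FP16/BF16 |
| Linear (weights) | FP16/BF16 | Matmuls tolerate reduced precision with little accuracy loss |
| Layer norm | FP32 | Mean and variance reductions lose precision in 16 bits |
| Loss | FP32 | Reductions need precision, and the scaled loss needs headroom |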
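The FP32 rows in the table all involve reductions (sums over many elements). As a rough illustration of why, the sketch below sums 4096 ones with an accumulator that is rounded to FP16 after every addition. This is a deliberately naive model; real kernels typically accumulate in higher precision, which is exactly the point. `to_fp16` is a hypothetical helper based on the standard-library `struct` module.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

n = 4096
acc_fp16 = 0.0   # accumulator rounded to FP16 after each add
acc_fp32 = 0.0   # higher-precision accumulator (Python float stands in for FP32)

for _ in range(n):
    # Above 2048, FP16 values are 2 apart, so 2048 + 1 rounds back to 2048.
    acc_fp16 = to_fp16(acc_fp16 + 1.0)
    acc_fp32 += 1.0

print(acc_fp16)   # 2048.0: the sum stalls halfway
print(acc_fp32)   # 4096.0
```

A mean or variance computed from the stalled sum would be off by a factor of two, which is the kind of error that keeps layer norm, softmax, and loss reductions in FP32.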
Example: PyTorch Mixed-Precision
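# model, optimizer, loss_fn, and dataloader are assumed to be defined already
# (newer PyTorch releases prefer torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda"))
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()  # manages the loss scale dynamically
for inputs, target in dataloader:
    optimizer.zero_grad()  # clear gradients from the previous step
    with autocast():  # forward pass: eligible ops run in FP16
        output = model(inputs)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)  # unscales gradients, updates FP32 weights, skips the step on inf/NaN
    scaler.update()  # adjusts the scale for the next iteration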
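The scale is not fixed during training: it is lowered when gradients overflow and raised again after a run of clean steps. The class below is a simplified model of that policy, not PyTorch's implementation. `ToyGradScaler` and its `update` method are hypothetical names; the default values mirror the documented defaults of `GradScaler` (`init_scale=65536`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`).

```python
import math

class ToyGradScaler:
    """Simplified model of dynamic loss scaling (illustration only)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads) -> bool:
        """Inspect this step's gradients; return True if the step should be applied."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            # Overflow: shrink the scale and skip this optimizer step.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Enough clean steps in a row: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True

toy = ToyGradScaler(growth_interval=2)       # short interval for the demo
print(toy.update([0.1, float("inf")]), toy.scale)   # False 32768.0
print(toy.update([0.1, 0.2]), toy.scale)            # True 32768.0
print(toy.update([0.3, 0.1]), toy.scale)            # True 65536.0
```

Skipping the step on overflow costs one batch, but it lets the scale sit as high as the gradients allow, which keeps the smallest gradients representable.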
Day 15: Production quantization in real chips (Apple, Google, NVIDIA).