Hi everyone,
I’m currently maintaining two separate training loops for my PyTorch model: one standard (float32) and one using automatic mixed precision (AMP) with torch.amp. To simplify my codebase, I would like to merge them into a single training loop by using the enabled flag on both torch.amp.GradScaler and torch.autocast.
Here’s my question:
If I use this pattern:
scaler = torch.amp.GradScaler(enabled=config.enable_amp)

for x, y in train_loader:
    # forward pass and loss inside the autocast context
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=config.enable_amp):
        logits = model(x)
        loss = criterion(logits, y)
    # backward, optimizer step, and scaler update outside autocast
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
When config.enable_amp is False, will the behavior be exactly equivalent to standard full-precision training, i.e.:
- scaler.scale(loss).backward() → behaves like loss.backward()
- scaler.step(optimizer) → behaves like optimizer.step()
- scaler.update() → becomes a no-op
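In other words, I’d expect the whole thing to reduce to the plain float32 loop I’m maintaining today (same model, criterion, optimizer, and train_loader as above):

for x, y in train_loader:
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()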
Is this the correct and safe way to generalize the training loop for both AMP and non-AMP modes without duplicating logic?
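In case it’s useful, this is the quick self-contained check I was planning to run to convince myself (just a sketch: the tiny Linear model, random data, SGD hyperparameters, and the CPU device are placeholders, not my real setup):

import copy

import torch

torch.manual_seed(0)
model_a = torch.nn.Linear(10, 2)    # toy stand-in for the real model
model_b = copy.deepcopy(model_a)    # identical weights for a fair comparison
opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1)
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
x = torch.randn(8, 10)
y = torch.randint(0, 2, (8,))

# Path A: plain float32 step
loss_a = criterion(model_a(x), y)
loss_a.backward()
opt_a.step()
opt_a.zero_grad()

# Path B: the same step routed through autocast/GradScaler with enabled=False
scaler = torch.amp.GradScaler(enabled=False)
with torch.autocast(device_type="cpu", dtype=torch.float16, enabled=False):
    loss_b = criterion(model_b(x), y)
scaler.scale(loss_b).backward()
scaler.step(opt_b)
scaler.update()
opt_b.zero_grad()

# If the disabled path is a true pass-through, the updated weights should match exactly
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    assert torch.equal(p_a, p_b)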
Thanks in advance!