Can GradScaler and autocast be safely unified with the enabled flag for both AMP and non-AMP training?

Hi everyone,

I’m currently maintaining two separate training loops for my PyTorch model: one standard (float32) and one using automatic mixed precision (AMP) with torch.amp.

To simplify my codebase, I would like to merge them into a single training loop using the enabled flag for both torch.amp.GradScaler and torch.autocast.

Here’s my question:

If I use this pattern:

scaler = torch.amp.GradScaler(enabled=config.enable_amp)

for x, y in train_loader:
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=config.enable_amp):
        logits = model(x)
        loss = criterion(logits, y)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

When config.enable_amp is False, will the behavior be exactly equivalent to standard full-precision training, i.e.:

  • scaler.scale(loss).backward() → behaves like loss.backward()
  • scaler.step(optimizer) → behaves like optimizer.step()
  • scaler.update() → becomes a no-op

Is this the correct and safe way to generalize the training loop for both AMP and non-AMP modes without duplicating logic?

Thanks in advance!

Yes, this should be the case if you properly pass the enabled argument to GradScaler and autocast.

E.g. scale() will just return the passed outputs and act as a no-op, step() will just call optimizer.step() if it's not enabled, etc.
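
A minimal sketch to verify this, assuming a recent PyTorch 2.x release where torch.amp.GradScaler takes the device type as its first argument (the model, tensors, and hyperparameters are just placeholders):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# With enabled=False, autocast doesn't cast anything and GradScaler is transparent.
scaler = torch.amp.GradScaler(device.type, enabled=False)

model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4, device=device)
with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=False):
    logits = model(x)
    loss = logits.sum()

print(logits.dtype)                  # torch.float32 -> autocast did not change the dtype
print(scaler.scale(loss) is loss)    # True -> scale() returned the input unchanged

scaler.scale(loss).backward()        # same as loss.backward()
scaler.step(optimizer)               # same as optimizer.step()
scaler.update()                      # no-op
optimizer.zero_grad()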

Also, a quick follow-up question:

Is it OK if my validation loop does not use mixed precision (i.e., runs in full float32), even when training uses AMP? Or should I also wrap the validation step in:

with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=config.enable_amp):
    logits = model(x)
    loss = criterion(logits, y)

I ask because I use the validation loss for early stopping, so I want to make sure that skipping AMP in validation won't invalidate or misalign the early-stopping behavior.

Thanks again!

Yes, you can use FP32 in the validation loop if needed.
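
E.g. a minimal FP32 validation sketch, reusing model and criterion from your training loop (val_loader is just a placeholder name for your validation DataLoader):

model.eval()
val_loss = 0.0
with torch.no_grad():                # no autocast context -> everything runs in float32
    for x, y in val_loader:
        logits = model(x)
        val_loss += criterion(logits, y).item()
val_loss /= len(val_loader)          # average validation loss, e.g. for early stopping
model.train()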