How should use_amp be set when using torch.cuda.amp?

When using:

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
with torch.cuda.amp.autocast(enabled=use_amp):
    ...

how should use_amp be determined?
Should it explicitly reflect whether the GPU supports AMP, or is torch.cuda.is_available() sufficient?

Also, when using DDP, what is the recommended way to handle AMP in case different GPUs in the cluster have different AMP support?
Is heterogeneous AMP across ranks supported, or should AMP be enabled/disabled globally?

Unless you are using an older Pascal GPU (or an even earlier generation), all NVIDIA GPUs support AMP via float16. Ampere and newer GPUs also support AMP in bfloat16.
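
As a minimal sketch (the flag and dtype names are chosen here for illustration, and the dtype selection is an assumption rather than a fixed rule), you could derive the setting from the detected hardware:

    import torch

    # Enable AMP only when a CUDA device is present; every supported
    # NVIDIA GPU can run autocast in float16.
    use_amp = torch.cuda.is_available()

    # On Ampere and newer, bfloat16 is also available.
    amp_dtype = torch.float16
    if use_amp and torch.cuda.is_bf16_supported():
        amp_dtype = torch.bfloat16

    # GradScaler is mainly needed for float16; it is harmless otherwise.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
        ...  # forward pass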

I would enable or disable it globally in multi-GPU setups.
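
If you want to enforce that agreement programmatically, here is a hedged sketch (assuming the NCCL process group is already initialized, every rank has a CUDA device, and using an illustrative per-rank criterion; the variable names are not from any API):

    import torch
    import torch.distributed as dist

    # Capabilities may differ across ranks; decide locally first.
    major, _ = torch.cuda.get_device_capability()
    local_amp = major >= 7  # illustrative criterion, e.g. Volta or newer

    # All-reduce with MIN: AMP stays enabled only if every rank agrees.
    flag = torch.tensor(int(local_amp), device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    use_amp = bool(flag.item())

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)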
