Fused mixed precision updates with PyTorch amp

I am trying to get up to speed on the latest support for mixed precision in PyTorch. One thing that is not clear to me right now is the interaction among autocast, DistributedDataParallel and optimizer weight updates.

Suppose I am using bfloat16 (so we can ignore grad scaling for the moment) and I simply run:

```python
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    loss = model(...)
loss.backward()
```

At that point, the .grad attributes on the model appear to be fp32 (matching the weights), because autograd backpropagates through the cast to bfloat16 that autocast inserted for each weight in the forward pass. Ideally I could:
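A minimal sketch that reproduces this observation, assuming a toy nn.Linear model. CPU autocast is used only so it runs without a GPU, and the helper names are hypothetical. It also shows why a bf16 .grad cannot simply be attached to an fp32 parameter: the grad's dtype has to match the parameter's.

```python
import torch
import torch.nn as nn


def grad_dtypes_under_autocast():
    """Return (activation dtype, weight dtype, weight.grad dtype) for one autocast step."""
    torch.manual_seed(0)
    model = nn.Linear(8, 4)  # parameters are created in float32
    x = torch.randn(2, 8)
    with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(x)  # linear runs in bfloat16 under autocast
    loss = out.float().pow(2).mean()  # reduce in fp32, outside the autocast region
    loss.backward()  # backward of the autocast cast turns the grad back into fp32
    return out.dtype, model.weight.dtype, model.weight.grad.dtype


def can_assign_bf16_grad():
    """Try to attach a bf16 gradient to an fp32 parameter."""
    p = nn.Parameter(torch.zeros(3))
    try:
        p.grad = torch.zeros(3, dtype=torch.bfloat16)
    except RuntimeError:  # PyTorch rejects a grad whose dtype differs from the param
        return False
    return True


out_dtype, w_dtype, g_dtype = grad_dtypes_under_autocast()
print(out_dtype, w_dtype, g_dtype)
print(can_assign_bf16_grad())
```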

  • Produce bfloat16 grads from backprop
  • (Optionally) allreduce those grads (with DDP)
  • Run a single weight-update kernel that reads the bfloat16 grads and then updates the fp32 weights (and optimizer state)
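For reference, the update in the third bullet amounts to the math below. This is a minimal, unfused sketch (adam_step_from_bf16_grads is a hypothetical name, not a PyTorch API) that reads bfloat16 grads, upcasts them on the fly, and updates fp32 weights and fp32 Adam moments with the same formula as torch.optim.Adam without weight decay. A fused kernel would do these operations in a single pass over memory; this version makes several.

```python
import torch


@torch.no_grad()
def adam_step_from_bf16_grads(params, grads_bf16, exp_avgs, exp_avg_sqs, step,
                              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical, unfused Adam update: bf16 grads in, fp32 weights/state updated in place.

    `step` is the 1-based step count used for bias correction.
    """
    bias_c1 = 1 - beta1 ** step
    bias_c2 = 1 - beta2 ** step
    for p, g_bf16, m, v in zip(params, grads_bf16, exp_avgs, exp_avg_sqs):
        g = g_bf16.float()                             # upcast the bf16 grad on read
        m.mul_(beta1).add_(g, alpha=1 - beta1)         # fp32 first moment
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # fp32 second moment
        denom = (v / bias_c2).sqrt_().add_(eps)
        p.addcdiv_(m, denom, value=-lr / bias_c1)      # fp32 weight update


# Usage: two steps on a toy fp32 weight with a fixed bf16 gradient.
torch.manual_seed(0)
weight = torch.randn(5)
grad_bf16 = torch.randn(5).to(torch.bfloat16)
m, v = torch.zeros(5), torch.zeros(5)
for t in (1, 2):
    adam_step_from_bf16_grads([weight], [grad_bf16], [m], [v], step=t, lr=1e-2)
```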

There is clearly some support for things like this (e.g., I see that the Adam optimizer has a fair amount of logic for passing fp16 scale values into the step function, which only makes sense if you are operating on fp16 grads), but it’s not clear how to reason about which values are in which dtype, and when. Thanks!
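One low-tech way to see which values are in which dtype is to inspect them directly after a step. The sketch below uses a hypothetical helper (dtype_report is not a PyTorch API) that walks model.named_parameters() and optimizer.state; with CPU autocast (chosen only so it runs without a GPU) and a stock torch.optim.Adam, weights, grads, and Adam's moment buffers all come out fp32.

```python
import torch
import torch.nn as nn


def dtype_report(model, optimizer):
    """Hypothetical helper: dtypes of each parameter, its grad, and its optimizer state."""
    report = {}
    for name, p in model.named_parameters():
        state = optimizer.state.get(p, {})
        report[name] = {
            "param": p.dtype,
            "grad": None if p.grad is None else p.grad.dtype,
            # Tensor-valued state only, e.g. Adam's exp_avg and exp_avg_sq.
            "state": {k: t.dtype for k, t in state.items() if torch.is_tensor(t)},
        }
    return report


torch.manual_seed(0)
model = nn.Linear(8, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(torch.randn(2, 8))  # bf16 activations
out.float().sum().backward()
opt.step()

report = dtype_report(model, opt)
for name, info in report.items():
    print(name, info)
```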

  1. I don’t think this is directly possible since the gradients would be computed using the dtype, device, and memory layout of the corresponding parameter.

  2. You might try to use DDP comm hooks to transform the gradients before communicating them.

  3. Same as in 1: I don’t believe a dtype mismatch between parameters and gradients is supported, but I’ll let others correct me in case there are some (experimental) implementations in the wild.
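Point 2 could look roughly like the sketch below, which uses the built-in compression hooks in torch.distributed.algorithms.ddp_comm_hooks.default_hooks. As I understand them, these hooks cast each gradient bucket to bf16 (or fp16) for the allreduce and copy the result back into the original bucket, so the parameters' .grad tensors stay fp32: communication volume is halved, but the optimizer still sees fp32 grads. The helper names are hypothetical, and the single-process gloo demo uses fp16 only because bf16 collective support on gloo may depend on the PyTorch version; with NCCL on GPUs, bf16 is the natural choice.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_with_half_precision_allreduce(model, use_bf16=True):
    """Hypothetical helper: DDP-wrap `model` and compress grad buckets for the allreduce."""
    ddp_model = DDP(model)
    hook = default_hooks.bf16_compress_hook if use_bf16 else default_hooks.fp16_compress_hook
    ddp_model.register_comm_hook(state=None, hook=hook)  # state=None -> default process group
    return ddp_model


def single_process_demo():
    """One-rank gloo group on CPU, only so the sketch runs without torchrun or GPUs."""
    with tempfile.TemporaryDirectory() as tmp:
        dist.init_process_group(
            "gloo", init_method=f"file://{os.path.join(tmp, 'store')}", rank=0, world_size=1
        )
        try:
            ddp_model = wrap_with_half_precision_allreduce(nn.Linear(8, 4), use_bf16=False)
            ddp_model(torch.randn(2, 8)).sum().backward()
            # The hook copies the reduced values back, so grads keep the param dtype.
            return [p.grad.dtype for p in ddp_model.parameters()]
        finally:
            dist.destroy_process_group()
```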

Thanks for the reply, Piotr! So it sounds like if we want to (e.g.) perform a half-precision allreduce, we still need to do something like the (now deprecated, I see) FP16_Optimizer – i.e., explicitly cast the model to half precision and manually manage the model and master weights at each weight update?

Hi Carl! Yes, if you want to apply the now deprecated apex.amp O2-style mixed-precision training, you would still need to handle the master parameters and gradients manually.
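Handling the master parameters and gradients manually might look roughly like this minimal sketch, assuming bf16 and a stock torch.optim.AdamW. MasterWeightTrainer is a hypothetical helper, not an apex or PyTorch API: it keeps a bf16 working copy of the model for forward/backward (so its grads, and a DDP allreduce of them, would be bf16), an fp32 master copy that the optimizer updates, and copies the masters back into the bf16 model after each step. Unlike the single fused kernel asked about above, this takes several passes over memory per step.

```python
import copy

import torch
import torch.nn as nn


class MasterWeightTrainer:
    """Hypothetical O2-style helper: bf16 model for compute, fp32 masters for the update."""

    def __init__(self, model_fp32, lr=1e-3):
        self.model = copy.deepcopy(model_fp32).to(torch.bfloat16)  # low-precision working copy
        # fp32 master parameters, independent of the bf16 model's graph.
        self.master = [p.detach().clone().float().requires_grad_(True)
                       for p in model_fp32.parameters()]
        self.opt = torch.optim.AdamW(self.master, lr=lr)  # optimizer state lives in fp32

    def step(self, x, y):
        out = self.model(x.to(torch.bfloat16))
        loss = nn.functional.mse_loss(out.float(), y)
        self.model.zero_grad(set_to_none=True)
        loss.backward()  # grads on self.model are bf16 (with DDP, allreduced in bf16 here)
        for master_p, p in zip(self.master, self.model.parameters()):
            master_p.grad = p.grad.float()  # upcast each bf16 grad into its fp32 master
        self.opt.step()
        self.opt.zero_grad(set_to_none=True)
        with torch.no_grad():
            for master_p, p in zip(self.master, self.model.parameters()):
                p.copy_(master_p)  # copy_ casts fp32 -> bf16
        return loss.item()


torch.manual_seed(0)
trainer = MasterWeightTrainer(nn.Linear(4, 1), lr=1e-2)
x, y = torch.randn(16, 4), torch.randn(16, 1)
losses = [trainer.step(x, y) for _ in range(5)]
```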
However, you could check the experimental fsdp.MixedPrecision class, which allows you to specify e.g. reduce_dtype and keep_low_precision_grads and might come close to your use case.
Of course, you would need to use FSDP, so I’m not sure whether this would work for you.
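The MixedPrecision suggestion might be configured roughly as follows. This is a sketch, not a tested recipe: the policy fields are real MixedPrecision options, but wrap_model is a hypothetical helper, and actually wrapping a model requires an initialized process group (e.g. launched with torchrun). Per the FSDP documentation, keep_low_precision_grads=True keeps gradients in the reduction dtype instead of upcasting them to full precision after backward, which mainly makes sense with an optimizer that can consume low-precision gradients.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# bf16 compute, bf16 gradient reduction, and bf16 gradients kept after backward.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,     # dtype of the unsharded params used in forward/backward
    reduce_dtype=torch.bfloat16,    # dtype used for gradient reduction
    buffer_dtype=torch.bfloat16,    # dtype for module buffers
    keep_low_precision_grads=True,  # do not upcast grads to fp32 after backward
)


def wrap_model(model: torch.nn.Module) -> FSDP:
    """Hypothetical helper; call only after the process group has been initialized."""
    return FSDP(model, mixed_precision=bf16_policy)
```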

Thanks, Piotr – that all makes sense. And thanks for the pointer to FSDP – I’m actually interested in exactly that, so I’ll play around with the MixedPrecision interface. I think a combination of FSDP and apex’s FusedAdam implementation (which supports some degree of explicit master weight management) may get the job done.
