BFloat16 training - explicit cast vs autocast

In PyTorch, there seem to be two ways to train a model in the bf16 dtype.

One is to explicitly cast both the input data and the model to bfloat16 (e.g. with .to(torch.bfloat16)).

Another is to use the torch.autocast(device_type=device, dtype=torch.bfloat16) context manager, in which case you don't need to explicitly cast the input data or the model to bfloat16.

Is there any difference between these two options? If so, which one is recommended?

Mixed-precision training via amp keeps the parameters in float32 and downcasts activations to the desired lower dtype for operations that are safe to run in it. Calling .to() on the model directly transforms all parameters and buffers to this dtype, and it's the user's responsibility to avoid under- and overflows.
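To make the "user's responsibility" part concrete, here is a small sketch with made-up values: float16 overflows above its maximum of 65504, while bfloat16 keeps float32's exponent range but has far fewer significand bits, so a small update to a weight stored in bfloat16 can be rounded away entirely. Keeping the parameters in float32, as amp does, avoids the latter problem for the weight update.

```python
import torch

# Overflow: 70000 exceeds float16's maximum (65504) but not bfloat16's range.
big = torch.tensor(70000.0)
big_fp16 = big.to(torch.float16)   # overflows to inf
big_bf16 = big.to(torch.bfloat16)  # finite, though rounded to a nearby value

# Precision: near 1.0, bfloat16 values are spaced 2**-7 apart,
# so adding 1e-3 rounds back to exactly 1.0.
update = 1e-3
w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16) + update  # stays 1.0
w_fp32 = torch.tensor(1.0) + update                         # about 1.001
```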

Thanks @ptrblck. I have a follow-up question about the memory-usage implications of these options.

According to some articles, like this one (Performance and Scalability: How To Fit a Bigger Model and Train It Faster, transformers 4.7.0 documentation), training in mixed precision with fp16 actually increases the memory footprint of the model weights, as both fp32 and fp16 copies of the weights are kept in GPU memory. Is this also the case for mixed precision with bf16?

That's not the case for float16 mixed-precision training via torch.cuda.amp. The weight copies were used in the O2 training recipe of apex, before amp was implemented directly in PyTorch.

Thanks @ptrblck. What you said is both interesting and somewhat surprising, because even the latest version of the HuggingFace performance tuning guide (Methods and tools for efficient training on a single GPU) still says that fp16 mixed-precision training uses 1.5x the memory for the model weights.

Given what you said, I was wondering whether there is any benefit at all to explicitly casting to bf16 instead of using autocast to bf16. I'm confused because I noticed that the torchtune package PyTorch announced last month uses the explicit-cast-to-bf16 option. See torchtune/recipes/ at f3611e505dfb3a7fd1bf81181ea250e63d854e0e in pytorch/torchtune on GitHub, where they basically use a context manager built on torch.set_default_dtype(torch.bfloat16) to create the data and model in bf16; also see line 47, where they mention that they do not use mixed-precision training. Any idea why they would choose this option?
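The general pattern can be sketched as a context manager around torch.set_default_dtype, so that parameters are created directly in bfloat16 instead of being cast afterwards. This is a hypothetical helper (default_dtype is an illustrative name), not torchtune's actual code, and it assumes a recent PyTorch 2.x release in which torch.set_default_dtype accepts torch.bfloat16.

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def default_dtype(dtype: torch.dtype):
    """Hypothetical helper: temporarily change the default floating-point dtype."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if model construction fails.
        torch.set_default_dtype(previous)

with default_dtype(torch.bfloat16):
    # Parameters are allocated and initialized directly in bfloat16,
    # so no float32 copy of the weights is ever created.
    model = nn.Linear(16, 4)
```

After the block, model.weight.dtype is torch.bfloat16 while the global default is back to torch.float32, so later tensors are unaffected.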

Explicitly casting the model to a lower dtype reduces its memory footprint, since all parameters are converted. Keep in mind that amp keeps all parameters in float32 and applies casts to activations (no parameter copies are created unless you are using the deprecated apex.amp module). So if the training recipe and model are stable enough, explicit casts can work.

@ptrblck Aha, I think I finally understand your point. To recap, relative to keeping the parameters in plain float32:

  • old apex-based autocast to fp16/bf16 will use 1.5x memory for parameters
  • new amp-based autocast to fp16/bf16 will use 1x memory for parameters
  • explicit-cast to fp16/bf16 will use 0.5x memory for parameters
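The ratios in the recap can be checked with a quick count of parameter bytes, assuming a made-up toy model of two 1024x1024 linear layers. It counts parameters only (gradients, optimizer state and activations are ignored), and the 1.5x case simply adds a float32 master copy to a 2-byte copy rather than running apex itself.

```python
import torch
import torch.nn as nn

def param_bytes(module: nn.Module) -> int:
    """Bytes occupied by a module's parameters (buffers excluded)."""
    return sum(p.numel() * p.element_size() for p in module.parameters())

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# amp/autocast: parameters stay float32 (4 bytes each) -> 1x baseline
fp32_bytes = param_bytes(model)

# explicit cast: nn.Module.to() converts parameters in place to 2 bytes each -> 0.5x
bf16_bytes = param_bytes(model.to(torch.bfloat16))

# apex O2-style: float32 master weights plus a low-precision copy -> 1.5x
master_plus_copy = fp32_bytes + bf16_bytes

ratios = (bf16_bytes / fp32_bytes, master_plus_copy / fp32_bytes)
```

Here ratios evaluates to (0.5, 1.5).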