Bfloat16 training

Hey!
I was wondering about your experience with bf16 training, specifically the type of the gradients and the weights.
From the repositories I’ve checked, it seems there are two main approaches (assuming a single GPU, no FSDP, and all activations are recomputed during backward):

  1. BF16 weights + BF16 grads + FP32 master copy + FP32 optimizer state → (2 + 2 + 4 + 8) = 16 bytes/param (as in NeMo).

  2. BF16 weights + FP32 grad accumulation + FP32 optimizer state → (2 + 4 + 8) = 14 bytes/param (I haven’t found a reference, but Megatron-LM appears to do this).

I then questioned whether FP32 gradient accumulation or an FP32 copy of the weights is needed at all, i.e. just BF16 weights + BF16 gradients + FP32 optimizer state → (2 + 2 + 8) = 12 bytes/param, and came across this paper (page 6, Section 5.1 “Precision options”), case D−MW: BF16 weights, BF16 gradients, FP32 optimizer state, no master copy. However, according to Table 3, it performs worse than the first option.
So, have you experienced any performance degradation when gradients and weights are kept only in bf16? Why this is relevant to the PyTorch forum:
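The byte counts above can be checked with a small helper. This is a minimal sketch, assuming an Adam-style optimizer with two FP32 state buffers per parameter (the 8 bytes in the post) and ignoring activations, temporary buffers, and framework overhead; `bytes_per_param` is a hypothetical helper, not a library function:

```python
# Bytes per element for the storage dtypes used in the configurations above.
BYTES = {"bf16": 2, "fp32": 4}


def bytes_per_param(weights, grads, master=None, optim_states=2, optim_dtype="fp32"):
    """Per-parameter memory for weights, gradients, an optional master copy,
    and `optim_states` optimizer buffers (assumed 2, as for Adam's
    exp_avg and exp_avg_sq). Activations and temporaries are ignored."""
    total = BYTES[weights] + BYTES[grads]
    if master is not None:
        total += BYTES[master]
    total += optim_states * BYTES[optim_dtype]
    return total


configs = {
    "1: bf16 weights + bf16 grads + fp32 master + fp32 optimizer": bytes_per_param("bf16", "bf16", master="fp32"),
    "2: bf16 weights + fp32 grad accumulation + fp32 optimizer": bytes_per_param("bf16", "fp32"),
    "D-MW: bf16 weights + bf16 grads + fp32 optimizer, no master": bytes_per_param("bf16", "bf16"),
}
for name, n in configs.items():
    print(f"{name}: {n} bytes/param")
```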

The question:

Should I do:
    model.to(torch.bfloat16); x = x.to(torch.bfloat16)  # no master copy, gradients are bf16
or
    with torch.autocast('cuda', dtype=torch.bfloat16):
        losses = loss_fn(model, x, y)  # still have an FP32 master copy

Thank you very much for your input!

Best to stick with autocast.
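The two options from the question can be compared by checking which dtypes the parameters, gradients, and activations end up with. A minimal sketch, assuming a toy `nn.Linear` model and a hand-written MSE loss as hypothetical stand-ins for the question's `model` and `loss_fn`; it falls back to CPU autocast (which also supports bf16) when no GPU is available:

```python
import torch
import torch.nn as nn

# Hypothetical toy setup standing in for `model`, `x`, `y`, `loss_fn`.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
x = torch.randn(8, 16, device=device)
y = torch.randn(8, 4, device=device)


def mse(pred, target):
    # Plain MSE built from elementwise ops so it works in any float dtype.
    return ((pred - target) ** 2).mean()


# Option 1: model.to(torch.bfloat16). The parameters themselves are bf16,
# so there is no FP32 master copy and the gradients are bf16 as well.
model_a = nn.Linear(16, 4).to(device=device, dtype=torch.bfloat16)
loss_a = mse(model_a(x.to(torch.bfloat16)), y.to(torch.bfloat16))
loss_a.backward()
grad_a_dtype = model_a.weight.grad.dtype

# Option 2: autocast. Parameters stay FP32 and act as the master copy;
# eligible ops such as the linear layer run in bf16 inside the context.
model_b = nn.Linear(16, 4).to(device)
opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    out_b = model_b(x)
    loss_b = mse(out_b.float(), y)  # compute the loss in FP32
loss_b.backward()  # backward runs outside the autocast context
grad_b_dtype = model_b.weight.grad.dtype  # gradients match the FP32 params
opt_b.step()  # the update is applied to the FP32 weights
opt_b.zero_grad()

print("option 1:", model_a.weight.dtype, grad_a_dtype)
print("option 2:", out_b.dtype, model_b.weight.dtype, grad_b_dtype)
```

With autocast, the bf16 versions of the weights are produced on the fly inside the context, while the optimizer only ever updates the FP32 parameters.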

  1. The 2nd configuration you mentioned is not 14 bytes/param; it should be 16 bytes/param: BF16 compute weights (2 bytes) + FP32 master weights (4 bytes) + BF16 gradients (2 bytes) + FP32 optimizer states (8 bytes).
  2. As the paper you linked demonstrates (Table 3), even with FP32 optimizer states, skipping the master copy degrades training quality.
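Why the master copy matters can be sketched in pure Python. This is an illustrative simulation, not a measurement: it assumes plain SGD with a constant gradient of 1.0 and a learning rate of 1e-3 on a single weight starting at 1.0, and emulates bf16 and FP32 storage with the hypothetical helpers `to_bf16` (round-to-nearest-even on the float32 bit pattern) and `to_fp32`:

```python
import struct


def to_fp32(x):
    # Round a Python float (float64) to the nearest float32 value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    # Round to bfloat16: take the float32 bit pattern and round it to its
    # upper 16 bits with round-to-nearest-even (finite inputs only).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


lr, grad, steps = 1e-3, 1.0, 100

# No master copy: the weight is stored in bf16, so every update is
# rounded back to bf16 immediately.
w_pure = to_bf16(1.0)
for _ in range(steps):
    w_pure = to_bf16(w_pure - lr * grad)

# FP32 master copy: updates accumulate in FP32, and the bf16 weight used
# for compute is derived from the master copy.
w_master = to_fp32(1.0)
for _ in range(steps):
    w_master = to_fp32(w_master - lr * grad)
w_compute = to_bf16(w_master)

print("pure bf16:", w_pure)
print("fp32 master:", w_master, "-> bf16 compute copy:", w_compute)
```

The pure-bf16 weight never moves off 1.0: each 0.001 step is smaller than half the bf16 spacing just below 1.0 (2^-9 ≈ 0.002), so it is rounded away every time, while the FP32 master copy ends close to 0.9.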

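Yes. It will hurt performance: as shown in the figure below, a significant number of small values become zero when cast to fp16, and although bf16 keeps the FP32 exponent range, its 8-bit significand rounds away weight updates that are small relative to the weights.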
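Casting to a 16-bit format loses information in two different ways, which this pure-Python sketch contrasts. It assumes Python's `struct` half-precision format `'e'` as an fp16 emulation and a hypothetical `to_bf16` helper that rounds the float32 bit pattern to its upper 16 bits:

```python
import struct


def to_fp16(x):
    # Round to IEEE 754 half precision via struct's 'e' format.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    # Round to bfloat16: round-to-nearest-even on the upper 16 bits of the
    # float32 bit pattern (finite inputs only).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Range: a tiny value underflows to zero in fp16 but survives in bf16,
# which keeps float32's 8-bit exponent.
tiny = 1e-8
print("fp16:", to_fp16(tiny), " bf16:", to_bf16(tiny))

# Precision: bf16 keeps only 8 significand bits, so a small update added
# to a weight of 1.0 is rounded away; fp16's 11 bits still resolve it.
w = 1.0 + 1e-3
print("bf16:", to_bf16(w), " fp16:", to_fp16(w))
```

fp16 flushes 1e-8 to zero but still distinguishes 1.001 from 1.0; bf16 does the reverse, which is why bf16 training still benefits from an FP32 master copy even though it rarely needs loss scaling.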

There are various design choices to overcome this issue; for more information, you may refer to this page on NVIDIA’s AMP (Automatic Mixed Precision). As suggested by @J_Johnson, it is preferable to use autocast.