Why keep parameters in float32 and not in (b)float16?

This thread explains the difference:

bfloat16 is more ML friendly: it gives a much wider range of values, at the cost of less precision in each stepwise increment.
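
A quick way to see the trade-off is to compare the numeric limits PyTorch reports for each dtype (a minimal sketch; torch.finfo is standard PyTorch):

import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # .tiny = smallest normal positive value, .eps = gap between 1.0 and the next value
    print(dtype, "min normal:", info.tiny, "max:", info.max, "eps:", info.eps)

You should see that float16 has a finer eps but a far narrower range, while bfloat16 matches float32's range with a much coarser eps.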

Take any large model and print out its parameters:

for param in model.parameters():  # model: any loaded torch.nn.Module
    print(param)

If you’re in float32 and the model has in excess of 1 billion parameters, you will likely see that many of the values are very small.
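
To put a number on that, you could count how many parameters fall below float16's normal range (a sketch, reusing the model from the snippet above):

import torch

tiny = torch.finfo(torch.float16).tiny   # 6.10e-05, smallest normal float16 value
below = sum((p.abs() < tiny).sum().item() for p in model.parameters())
total = sum(p.numel() for p in model.parameters())
print(f"{below / total:.1%} of parameters are below float16's normal range")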

However, the smallest normal positive value float16 can represent is about 6.10 × 10^-5, whereas bfloat16 can go down to about 1.18 × 10^-38 (the same exponent range as float32). Hence float16 may require additional (loss) scaling to avoid underflow.
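
To see the underflow concretely (a small illustration; 1e-8 is just an arbitrary example of a tiny gradient-sized value):

import torch

g = torch.tensor(1e-8)
print(g.to(torch.float16))   # prints 0. -- below even float16's subnormal range
print(g.to(torch.bfloat16))  # prints roughly 1e-08 -- the magnitude survives

This is why float16 training typically pairs with loss scaling, while bfloat16 usually does not need it.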

Additionally, because bfloat16 caps precision, the coarser rounding may act as a mild form of regularization that prevents some overfitting, which can help models trained in bfloat16 generalize better.