How can I prevent infinity values after switching from float32 to float16?

I have a PyTorch script involving division ops that worked fine with torch.float32:

A = model(X[0])
B = model(X[1])
C = A/B

However, when I switched to torch.float16 by doing:

model = model.half()
A = model(X[0].to(torch.float16))
B = model(X[1].to(torch.float16))
C = A/B

I got infinity values inside C, because one element of B is tiny:

tensor([5.9605e-08], dtype=torch.float16)
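
For reference, 5.9605e-08 is the smallest positive (subnormal) float16 value, and dividing by it exceeds float16's maximum of about 65504, which is why the result becomes inf. A minimal standalone reproduction, separate from my model:

import torch

a = torch.tensor([1.0], dtype=torch.float16)
b = torch.tensor([5.9605e-08], dtype=torch.float16)  # smallest positive float16 value
print(a / b)                            # tensor([inf], dtype=torch.float16), since 1 / 5.9605e-08 ≈ 1.7e7
print(torch.finfo(torch.float16).max)   # 65504.0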

What should I do to fix this infinity issue after switching to torch.float16?

If you need more range with the same number of bits, bfloat16 can help here: it uses the same 8 exponent bits as float32, so it covers a much larger range than float16.
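
A minimal sketch, assuming the model and inputs from your snippet:

import torch

model = model.to(torch.bfloat16)
A = model(X[0].to(torch.bfloat16))
B = model(X[1].to(torch.bfloat16))
C = A / B  # bfloat16 shares float32's exponent range, so dividing by ~5.96e-08 stays finite

# standalone check: the same tiny divisor no longer overflows
print(torch.tensor([1.0], dtype=torch.bfloat16) / torch.tensor([5.9605e-08], dtype=torch.bfloat16))
# finite (~1.678e+07) instead of inf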

You can also temporarily use higher precision in the parts of the model that need it, and keep lower precision in the parts that benefit from it, e.g. in terms of speed or memory.
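
For instance, a sketch that keeps the model in float16 but upcasts only the range-sensitive division (names taken from your snippet):

model = model.half()
A = model(X[0].to(torch.float16))
B = model(X[1].to(torch.float16))
C = A.float() / B.float()  # do the division in float32, which has enough range
# keep C in float32 for later ops; casting it back to float16 would overflow again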

The Automatic Mixed Precision package (torch.amp — PyTorch 2.5 documentation) is a way to automate this.
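
A rough sketch of how that could look here, assuming a CUDA GPU; the model parameters stay in float32 and autocast runs eligible ops in float16:

import torch

model = model.cuda().float()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    A = model(X[0].cuda())
    B = model(X[1].cuda())

# do the range-sensitive division in float32 outside the autocast region
C = A.float() / B.float()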
