I have a PyTorch script involving division ops that worked fine with torch.float32:
A = model(X[0])
B = model(X[1])
C = A/B
However, when I switched to torch.float16 by doing:
model = model.half()
A = model(X[0].to(torch.float16))
B = model(X[1].to(torch.float16))
C = A/B
I got inf values in C, because one element of B is
tensor([5.9605e-08], dtype=torch.float16)
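For reference, 5.9605e-08 is the smallest positive subnormal float16 value (2**-24), so dividing by it overflows whenever the numerator exceeds roughly 0.0039 (float16 max is 65504). A minimal standalone reproduction, using toy tensors instead of my actual model outputs:
import torch

# 5.9605e-08 rounds to 2**-24, the smallest positive float16 subnormal.
a = torch.tensor([1.0], dtype=torch.float16)
b = torch.tensor([5.9605e-08], dtype=torch.float16)
print(a / b)  # tensor([inf], dtype=torch.float16)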
What should I do to fix this infinity issue after switching to torch.float16?
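Here is a minimal sketch of two workarounds I have considered (the clamp threshold 1e-4 is a hypothetical value to tune, and the clamp variant assumes B is strictly positive); I am not sure either is the right approach:
import torch

# Toy tensors standing in for the model outputs A and B.
A = torch.tensor([1.0], dtype=torch.float16)
B = torch.tensor([5.9605e-08], dtype=torch.float16)

# Option 1: do the division in float32 and keep the quotient in
# float32; casting it back to float16 would overflow to inf again.
C = A.float() / B.float()
print(C)  # tensor([16777216.])

# Option 2: clamp tiny denominators so the quotient stays within the
# float16 range (max ~65504). Only valid if B is known to be positive.
C = A / B.clamp(min=1e-4)
print(C)  # tensor([10000.], dtype=torch.float16)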