So, `grad_input` is the gradient of the loss w.r.t. the layer input.
I see you’re using `torch.float16`. The maximum representable value at that bit-width is about 6.55 × 10^4 (65504), and you stated that AMP scales the loss to 600,000. That exceeds the float16 range, so your scaled loss overflows, and that’s where the Inf is coming from!
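A quick way to see this is to cast the scaled value to half precision directly. This sketch uses NumPy's `float16`, which follows the same IEEE-754 half-precision format as `torch.float16` (600,000 is the example value from your post):

```python
import numpy as np

# float16 tops out at 65504; anything larger overflows to inf.
fp16_max = float(np.finfo(np.float16).max)   # 65504.0
scaled_loss = np.float16(600_000.0)          # exceeds the range -> inf

print(fp16_max)                # 65504.0
print(np.isinf(scaled_loss))   # True
```

Once an Inf shows up in the forward or backward pass, it propagates into every gradient it touches. Note that `torch.cuda.amp.GradScaler` is designed to recover from this on its own: when it detects Inf/NaN gradients it skips the optimizer step and lowers the scale, so a manually pinned scale of 600,000 defeats that mechanism.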