According to the PyTorch 2.1 documentation for the Automatic Mixed Precision package (torch.amp), we have to scale the loss so that the gradients can be represented without underflow. I get this part, but when we call unscale_, it essentially undoes the scaling, as the name suggests. So wouldn’t unscale_ make the gradients underflow again?
Great question. Take a linear layer as an example: the weight is implicitly broadcast across the batch so that it can be batch-matrix-multiplied (bmm) with the input. Hence, there’s a corresponding accumulation/reduction/summation when you compute the gradient of the weight from the gradient of the linear layer’s output. Since we unscale after doing the reduction, the underflow would not happen (because the reduction increases the magnitude of the values).
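To make the "reduce first, unscale afterwards" point concrete, here is a minimal sketch in plain Python rather than PyTorch. Everything in it is an assumption made for illustration: `to_fp16` is a hypothetical helper that mimics FP16 rounding through the standard library's half-precision `struct` format, the per-sample gradient values are made up, and the sum and the division by the scale are done in ordinary Python floats to stand in for higher-precision accumulation. The scale of 2**16 matches the default `init_scale` of `GradScaler`.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value (hypothetical helper, stdlib only)."""
    return struct.unpack("e", struct.pack("e", x))[0]


SCALE = 2.0 ** 16              # default init_scale of GradScaler
per_sample_grads = [1e-8] * 4  # made-up per-sample contributions to one weight gradient

# Without scaling: each contribution is far enough below FP16's smallest
# subnormal (about 6e-8) that it rounds to zero before it can be summed.
unscaled_sum = sum(to_fp16(g) for g in per_sample_grads)

# With scaling: each contribution is stored in FP16 at a representable
# magnitude, summed in higher precision, and only then divided by the scale.
scaled_sum = sum(to_fp16(g * SCALE) for g in per_sample_grads)
recovered = scaled_sum / SCALE  # the "unscale" step, done outside FP16

print(unscaled_sum)  # the unscaled contributions are all lost
print(recovered)     # close to the true sum of 4e-8
```

The unscale step does not reintroduce the underflow here because it is applied to a value that is no longer held in FP16. This is only a sketch of the arithmetic, not of what `GradScaler` does internally.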
Thanks for the reply! However, I am still a bit confused. Let’s say batch_size=1 and the network is only one linear layer. How would the scaler work in this case? Moreover, what is the difference between accumulation and summation?
The other aspect is one that I maybe wouldn’t describe as “underflow”. I think it is better thought of as “avoiding executing the function in a regime where you have lower precision”. The idea is that there is value in performing the function in higher precision and then unscaling, versus performing the function in low precision. This effect can compound if you have many layers.
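A rough way to see the "lower precision regime" point, again as a plain Python sketch under stated assumptions: `to_fp16` is a hypothetical helper that mimics FP16 rounding via the standard library's half-precision `struct` format, and the value 1e-6 is made up. It sits in FP16's subnormal range, where only a few significant bits are available, so it has not underflowed to zero but is stored coarsely.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value (hypothetical helper, stdlib only)."""
    return struct.unpack("e", struct.pack("e", x))[0]


SCALE = 2.0 ** 16  # default init_scale of GradScaler
x = 1e-6           # made-up gradient value inside FP16's subnormal range

# Stored directly in FP16: nonzero, but rounded on a coarse subnormal grid.
direct = to_fp16(x)

# Stored in FP16 after scaling, then unscaled in higher precision:
# the scaled value lands in the normal range with a full mantissa.
via_scaling = to_fp16(x * SCALE) / SCALE

err_direct = abs(direct - x) / x
err_scaled = abs(via_scaling - x) / x

print(f"relative error, direct FP16:      {err_direct:.2e}")
print(f"relative error, scaled then back: {err_scaled:.2e}")
```

In this sketch the direct FP16 value is off by roughly a percent, while the scaled-then-unscaled value is off by well under a tenth of a percent. With many layers, rounding errors of the first kind are the ones that can build on each other.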