When we use AMP, we find that the running_mean and running_var of self.shared_2 are normal before it receives x1. After producing x2, however, there are some NaNs in its running_mean and running_var. Maybe the values calculated from x1 exceed the range of half precision?
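For example, a check along these lines (a minimal sketch; `model` and `self.shared_2` as an nn.BatchNorm2d are assumed names from our setup) confirms whether the running stats contain non-finite values:

```python
import torch

# Hypothetical names: `model.shared_2` is assumed to be an nn.BatchNorm2d.
bn = model.shared_2
print("running_mean non-finite:", (~torch.isfinite(bn.running_mean)).any().item())
print("running_var non-finite:", (~torch.isfinite(bn.running_var)).any().item())
```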
Is there any elegant approach to avoid the NaNs? Thanks.
For instance, how could we keep the above conv layer in FP32?
This should not be the case since autocast will use float32 in batchnorm layers.
Most likely x1 already contains invalid values (Infs or NaNs) and is thus updating the running stats with invalid values, too.
Did you make sure the input does not contain any invalid values?
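Something like this quick sketch (with x1 standing for the activation that feeds self.shared_2) would check explicitly for non-finite values:

```python
import torch

# x1 is assumed to be the activation tensor in question.
num_nan = torch.isnan(x1).sum().item()
num_inf = torch.isinf(x1).sum().item()
print(f"NaNs: {num_nan}, Infs: {num_inf}, dtype: {x1.dtype}")
```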
We checked the min, max, and mean values of x1, and they are all normal.
Besides, the eps values everywhere in our code are set to 1e-3, including in the torch.log() calls. So, is the calculation of running_mean and running_var in the BN layer still done in FP16, and could it exceed the FP16 range?
No, the batchnorm layers will use float32 inside the autocast context, as already explained. Unless you are calling half() on the model and using manual float16 training, this operation should not overflow.
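You can verify this quickly with a standalone batchnorm layer (a minimal sketch; the shapes are arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64).cuda()
x = torch.randn(8, 64, 32, 32, device="cuda")

with torch.cuda.amp.autocast():
    out = bn(x)

# The running stats are float32 buffers and stay float32 under autocast;
# they would only become float16 if .half() were called on the module.
print(bn.running_mean.dtype, bn.running_var.dtype)  # torch.float32 torch.float32
```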
We know that autocast(enabled=False) can be used in forward() to run specific layers in FP32. Are there approaches to keep specific layers in FP32 from the __init__ function?
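One workaround we are considering is a small wrapper module that disables autocast in its own forward(), so it can be used directly in __init__ (a rough sketch; ForceFP32 is just our name for it):

```python
import torch
import torch.nn as nn

class ForceFP32(nn.Module):
    """Hypothetical wrapper: run the wrapped module with autocast disabled,
    so it always executes in float32 even inside an autocast region."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        with torch.cuda.amp.autocast(enabled=False):
            # Activations produced under autocast may be float16, so cast back.
            return self.module(x.float())

# e.g. in __init__:
# self.shared_2 = ForceFP32(nn.BatchNorm2d(64))
```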
Maybe we could find the problematic layer by progressively replacing FP16 layers with FP32 ones…
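Alternatively, instead of swapping layers one by one, forward hooks could flag the first module whose output becomes non-finite (a rough sketch; register_nan_hooks is just an illustrative helper):

```python
import torch

def register_nan_hooks(model):
    # Illustrative helper: report any module whose output contains
    # non-finite values during the forward pass.
    def make_hook(name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and not torch.isfinite(output).all():
                print(f"Non-finite output in {name} ({module.__class__.__name__})")
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
```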