I am trying to implement fp16 training for a VQ-GAN. The issue is that NaNs appear in the first batch when using the gradient scaler.
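For context, the training step follows the usual GradScaler recipe, roughly like the sketch below (the dummy model, optimizer, and batch are placeholders for illustration, not the actual VQ-GAN code):

import torch
import torch.nn as nn

# Dummy stand-in for the VQ-GAN (an assumption, for illustration only),
# cast to half the same way as the layers in the repro below.
model = nn.Sequential(
    nn.Conv2d(3, 128, 3, padding=1),
    nn.GroupNorm(16, 128),
    nn.SiLU(),
    nn.Conv2d(128, 3, 3, padding=1),
).cuda().half()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

images = torch.randn(4, 3, 64, 64, device="cuda").half()  # dummy batch

optimizer.zero_grad(set_to_none=True)
loss = nn.functional.mse_loss(model(images), images)
scaler.scale(loss).backward()  # scale the loss so small fp16 gradients don't underflow
scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/NaN
scaler.update()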
I tracked down the issue to a specific Conv2d layer and saved the input coming into it, which can be downloaded here:
Code to reproduce NaNs:
import torch
from torch.nn import GroupNorm, SiLU, Conv2d

inf = torch.load("inf.pt").cuda()  # the saved input tensor
gn1 = GroupNorm(16, 128, eps=1e-05, affine=True).cuda().half()
s = SiLU()
# out_channels values in the range [64, 73] produce NaNs
conv1 = Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)).cuda().half()
o_gn1 = gn1(inf)
o_s = s(o_gn1)
conv1_out = conv1(o_s)
conv1_out.min(), conv1_out.max()
(tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<MinBackward1>),
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>))
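To narrow down where the overflow happens, one quick check (a sketch reusing the tensors from the repro above) is to look at each intermediate tensor for inf/NaN and for magnitudes near the fp16 limit of ~65504:

# Print dtype, max magnitude, and inf/NaN flags for every intermediate tensor;
# an overflow in o_s or inside the conv should show up here.
for name, t in [("inf", inf), ("o_gn1", o_gn1), ("o_s", o_s), ("conv1_out", conv1_out)]:
    print(name, t.dtype,
          t.abs().max().item(),
          torch.isinf(t).any().item(),
          torch.isnan(t).any().item())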
The weird part is that when I set out_channels to anything outside the [64, 73] range, the NaNs suddenly disappear. Also, when I switch back to fp32, the problem does not occur. Am I missing something important? I would appreciate any help, thanks.
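One comparison that might help locate the issue (again a sketch reusing the tensors from the repro) is to run an fp32 copy of the same layer and check whether its output exceeds the fp16 maximum of ~65504:

import copy

# Run an fp32 copy of conv1 on the same input; if the fp32 output magnitude
# exceeds ~65504, the fp16 computation overflows, which can surface as inf
# or NaN in the half-precision result.
conv1_fp32 = copy.deepcopy(conv1).float()
out_fp32 = conv1_fp32(o_s.float())
print(out_fp32.abs().max().item())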
I am running the experiments on:
CUDA: 11.7
Torch: 2.0.1