How to fix this nan bug?

@tom @ptrblck @albanD Hi, sorry for the tagging, but I didn't receive any response, so I have to tag you.

As I said, after adding a gradient hook to vs, the gradients in the backbone (all those conv-bn-relu layers) are now normal. But I find that the gradients of the linear layers in this custom block contain NaN (the block is inserted in the middle of the backbone, e.g. after stage 3 of a ResNet or relu3-1 of a VGGNet). After the optimization step, the weights of these linear layers become NaN, and in the next iteration the model output becomes NaN as well (a sketch of this check is shown after the output below).

for name, param in self.model.backbone.named_parameters():
    # check each parameter's gradient for NaN entries
    print(name, torch.isnan(param.grad).any())
###### output
conv1.weight tensor(False, device='cuda:0')
bn1.weight tensor(False, device='cuda:0')
bn1.bias tensor(False, device='cuda:0')
layer1.0.conv1.weight tensor(False, device='cuda:0')
layer1.0.bn1.weight tensor(False, device='cuda:0')
layer1.0.bn1.bias tensor(False, device='cuda:0')
layer1.0.conv2.weight tensor(False, device='cuda:0')
layer1.0.bn2.weight tensor(False, device='cuda:0')
layer1.0.bn2.bias tensor(False, device='cuda:0')
layer1.1.conv1.weight tensor(False, device='cuda:0')
layer1.1.bn1.weight tensor(False, device='cuda:0')
layer1.1.bn1.bias tensor(False, device='cuda:0')
layer1.1.conv2.weight tensor(False, device='cuda:0')
layer1.1.bn2.weight tensor(False, device='cuda:0')
layer1.1.bn2.bias tensor(False, device='cuda:0')
layer2.0.conv1.weight tensor(False, device='cuda:0')
layer2.0.bn1.weight tensor(False, device='cuda:0')
layer2.0.bn1.bias tensor(False, device='cuda:0')
layer2.0.conv2.weight tensor(False, device='cuda:0')
layer2.0.bn2.weight tensor(False, device='cuda:0')
layer2.0.bn2.bias tensor(False, device='cuda:0')
layer2.0.downsample.0.weight tensor(False, device='cuda:0')
layer2.0.downsample.1.weight tensor(False, device='cuda:0')
layer2.0.downsample.1.bias tensor(False, device='cuda:0')
layer2.1.conv1.weight tensor(False, device='cuda:0')
layer2.1.bn1.weight tensor(False, device='cuda:0')
layer2.1.bn1.bias tensor(False, device='cuda:0')
layer2.1.conv2.weight tensor(False, device='cuda:0')
layer2.1.bn2.weight tensor(False, device='cuda:0')
layer2.1.bn2.bias tensor(False, device='cuda:0')
layer3.0.conv1.weight tensor(False, device='cuda:0')
layer3.0.bn1.weight tensor(False, device='cuda:0')
layer3.0.bn1.bias tensor(False, device='cuda:0')
layer3.0.conv2.weight tensor(False, device='cuda:0')
layer3.0.bn2.weight tensor(False, device='cuda:0')
layer3.0.bn2.bias tensor(False, device='cuda:0')
layer3.0.downsample.0.weight tensor(False, device='cuda:0')
layer3.0.downsample.1.weight tensor(False, device='cuda:0')
layer3.0.downsample.1.bias tensor(False, device='cuda:0')
layer3.1.conv1.weight tensor(False, device='cuda:0')
layer3.1.bn1.weight tensor(False, device='cuda:0')
layer3.1.bn1.bias tensor(False, device='cuda:0')
layer3.1.conv2.weight tensor(False, device='cuda:0')
layer3.1.bn2.weight tensor(False, device='cuda:0')
layer3.1.bn2.bias tensor(False, device='cuda:0')
layer4.0.conv1.weight tensor(False, device='cuda:0')
layer4.0.bn1.weight tensor(False, device='cuda:0')
layer4.0.bn1.bias tensor(False, device='cuda:0')
layer4.0.conv2.weight tensor(False, device='cuda:0')
layer4.0.bn2.weight tensor(False, device='cuda:0')
layer4.0.bn2.bias tensor(False, device='cuda:0')
layer4.0.downsample.0.weight tensor(False, device='cuda:0')
layer4.0.downsample.1.weight tensor(False, device='cuda:0')
layer4.0.downsample.1.bias tensor(False, device='cuda:0')
layer4.1.conv1.weight tensor(False, device='cuda:0')
layer4.1.bn1.weight tensor(False, device='cuda:0')
layer4.1.bn1.bias tensor(False, device='cuda:0')
layer4.1.conv2.weight tensor(False, device='cuda:0')
layer4.1.bn2.weight tensor(False, device='cuda:0')
layer4.1.bn2.bias tensor(False, device='cuda:0')
layer3_custom_block0.fc1.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc2.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.bias tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.bias tensor(False, device='cuda:0')
layer3_custom_block1.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.bias tensor(True, device='cuda:0')
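
Here is roughly the check I run right after the optimizer step, which is how I see the weights turning into NaN (a minimal sketch; optimizer is just the optimizer the backbone parameters are registered with):

optimizer.step()
# after the update, check whether the parameters themselves contain NaN
for name, param in self.model.backbone.named_parameters():
    if torch.isnan(param).any():
        print(name, 'became NaN after the optimizer step')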

P.S. The mean/std of new_vt is also calculated; I omitted that in the first post. I've also added a gradient hook to new_vt (a sketch of how it is attached follows the function below), but the NaN gradients in the fc layers still exist.

def exchange(self, vs, vt):
    # vs and vt are of the same size NxCxHxW
    vs_mean = torch.mean(vs, dim=(2, 3))
    vs_std = torch.std(vs, dim=(2, 3)) + self.eps
    raw_es_domain = torch.cat((vs_mean, vs_std), dim=1)
    es_domain = self.relu(self.fc1(raw_es_domain))
    es_domain = self.relu(self.fc2(es_domain))

    weight = self.fc_w(es_domain) # size of NxCx1x1
    bias = self.fc_b(es_domain) # size of NxCx1x1

    vt = nn.InstanceNorm2d(vt.size(1), affine=False)(vt)
    new_vt = weight * vt + bias

    new_vt_mean = torch.mean(new_vt, dim=(2, 3))
    new_vt_std = torch.std(new_vt, dim=(2, 3)) + self.eps
    raw_new_et_domain = torch.cat((new_vt_mean, new_vt_std), dim=1)
    new_et_domain = self.relu(self.fc1(raw_new_et_domain))
    new_et_domain = self.relu(self.fc2(new_et_domain))

    return vt, es_domain, new_et_domain
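
The hook on new_vt mentioned above is attached right after new_vt is computed inside exchange, roughly like this (a minimal sketch of what I do, using the same NaN-zeroing idea as the hook shown below):

# zero out NaN entries in the gradient flowing back into new_vt
new_vt.register_hook(lambda grad: torch.where(torch.isnan(grad), torch.zeros_like(grad), grad))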

The funny thing is, after I add hooks to monitor the gradients of weight and bias in custom_block1, the NaN in fc_w of custom_block1 just disappears, but the NaN in fc_b does not!

def hook(grad):
    # replace any NaN entries in the incoming gradient with zeros
    if torch.any(torch.isnan(grad)):
        print('fixing nan gradient')
    grad = torch.where(torch.logical_not(torch.isnan(grad)), grad, torch.zeros_like(grad))
    return grad

weight.register_hook(hook)
bias.register_hook(hook)
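
Note that weight and bias here are the intermediate tensors produced inside exchange, not the parameters of fc_w and fc_b themselves. A parameter-level version of the same hook would look roughly like this (just a sketch for clarity, not something I currently run):

# hypothetical: attach the same NaN-zeroing hook to the fc_b parameters directly
for p in self.model.backbone.layer3_custom_block1.fc_b.parameters():
    p.register_hook(hook)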

####### after the above hooks are registered, check the gradients of the custom blocks
layer3_custom_block0.fc1.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc2.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.bias tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.bias tensor(False, device='cuda:0')
layer3_custom_block1.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.weight tensor(False, device='cuda:0')
layer3_custom_block1.fc_w.bias tensor(False, device='cuda:0')
layer3_custom_block1.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.bias tensor(True, device='cuda:0')

I know this may be a dumb question, but how should I debug this?