How to fix this nan bug?

Hi, I’m trying to modify the mean/std of one feature using the mean/std calculated from another feature. It looks like this (somewhat simplified, since the original code is much more complicated):


def exchange(self, vs, vt):
    # vs and vt are of the same size NxCxHxW
    vs_mean = torch.mean(vs, dim=(2, 3))
    vs_std = torch.std(vs, dim=(2, 3)) + self.eps
    raw_es_domain = torch.cat((vs_mean, vs_std), dim=1)
    es_domain = self.relu(self.fc1(raw_es_domain))
    es_domain = self.relu(self.fc2(es_domain))

    weight = self.fc_w(es_domain)[:, :, None, None]  # size NxCx1x1, broadcasts against vt
    bias = self.fc_b(es_domain)[:, :, None, None]    # size NxCx1x1

    vt = nn.InstanceNorm2d(vt.size(1), affine=False)(vt)
    vt = weight * vt + bias

    return vt

However, there is this nan bug that happens sometimes, while at other times this code snippet works just fine. I’ve used torch.autograd.detect_anomaly() to debug it, and it told me the torch.std operation was the culprit. The simplified log information is as follows:

/opt/conda/conda-bld/pytorch_1573049306803/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
vs_std = torch.std(vs, dim=(2, 3)) + self.eps
loss.backward()
File "/home/user/anaconda3/envs/pytorch-1.3.1/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/user/anaconda3/envs/pytorch-1.3.1/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag

But I’ve already added an eps to vs_std. Is there any way to fix this bug? Or is there another way to find out what causes the nan? Thanks!

This is likely because vs is constant:

a = torch.zeros(10, requires_grad=True)
a.std().backward()
a.grad

gives
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])

What you could do is add a gradient hook and modify the gradient to replace the nan with 0 (using torch.where).
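A minimal sketch of such a hook might look like this (variable names follow the snippet in your question; replace_nan_with_zero is just an illustrative name):

def replace_nan_with_zero(grad):
    # grad is the gradient flowing back into vs; zero out any nan entries
    return torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)

vs.register_hook(replace_nan_with_zero)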

Best regards

Thomas

Thanks, Thomas!

That is indeed the case for me. I checked and found that, for example, vs[14, 205, :, :] was an all-zero tensor. So I assume replacing the nan gradient of vs with 0 should solve the problem.

However, I’ve tried adding gradient hooks to both vs and vt (or vs_std and vt_std), and it didn’t work.

def hook_fn_backward(grad):
    # replace nan entries in the incoming gradient with 0
    return torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)

vs.register_hook(hook_fn_backward)
vs_mean.register_hook(hook_fn_backward)
vs_std.register_hook(hook_fn_backward)

It still gives the following error:
RuntimeError: Function 'StdBackward1' returned nan values in its 0th output.

The strange thing is, I tried to set a trace inside hook_fn_backward, but there I only ever saw torch.any(torch.isnan(grad)) == False. So where exactly should I add this hook?

After further debugging, I find that adding a gradient hook to vs and modifying the gradient to replace the nan with 0 does solve the problem mentioned above. That is to say, the nan gradient from torch.std() is replaced with 0.

However, I then found that there is another nan bug in this code. Since I’m using torch.autograd.detect_anomaly() to find out which line is the culprit, the line
vs_std = torch.std(vs, dim=(2, 3)) + self.eps
always shows up in the error log. It seems that torch.autograd.detect_anomaly() still raises an error at this line even after adding a gradient hook to vs (in fact, I have to comment out torch.autograd.detect_anomaly() to make the code run at all), and that covers up the trace of the second nan bug.

What a code snippet! I never expected a second nan bug, so naturally I’ve wasted a lot of time on it.

Here is the thing: I’ve added gradient hooks to every module in this custom block, and I find that the linear layers also have nan gradients. Many linear layers receive nan grad_out and send nan grad_in, but there is also this one linear layer, fc_w, that receives a normal grad_out and sends a nan grad_in. At first, I thought the input es_domain was nan. But then I found out that es_domain became nan at iteration 1, while the nan gradient in the linear layer happened at iteration 0.
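Roughly, the per-module hooks look like this (just a sketch, not my exact code; make_module_hook is an illustrative name, and on newer PyTorch versions register_full_backward_hook would be the preferred API):

def make_module_hook(name):
    def module_hook(module, grad_input, grad_output):
        # grad_input / grad_output are tuples of tensors (entries may be None)
        nan_out = any(g is not None and torch.isnan(g).any() for g in grad_output)
        nan_in = any(g is not None and torch.isnan(g).any() for g in grad_input)
        print(name, 'nan grad_out:', nan_out, 'nan grad_in:', nan_in)
    return module_hook

for name, module in self.model.backbone.named_modules():
    if isinstance(module, nn.Linear):
        module.register_backward_hook(make_module_hook(name))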

So how should I fix this bug? Or is there any way to inspect the internal backprop procedure of a linear layer?

Thanks again!

@tom @ptrblck @albanD Hi, sorry for the tagging, but I didn’t receive any response, so I have to tag you.

Like I said, after adding a gradient hook to vs, the gradients in the backbone (all those conv-bn-relu layers) are now normal. But I find that the gradients of the linear layers in this custom block contain nan (this block is inserted in the middle of the backbone, e.g. after stage 3 of ResNet or relu3-1 of VGGNet). After the optimization step, the weights of these linear layers become nan, and in the next iteration the model output becomes nan too.

for name, param in self.model.backbone.named_parameters():
    print(name, torch.isnan(param.grad).any())
###### output
conv1.weight tensor(False, device='cuda:0')
bn1.weight tensor(False, device='cuda:0')
bn1.bias tensor(False, device='cuda:0')
layer1.0.conv1.weight tensor(False, device='cuda:0')
layer1.0.bn1.weight tensor(False, device='cuda:0')
layer1.0.bn1.bias tensor(False, device='cuda:0')
layer1.0.conv2.weight tensor(False, device='cuda:0')
layer1.0.bn2.weight tensor(False, device='cuda:0')
layer1.0.bn2.bias tensor(False, device='cuda:0')
layer1.1.conv1.weight tensor(False, device='cuda:0')
layer1.1.bn1.weight tensor(False, device='cuda:0')
layer1.1.bn1.bias tensor(False, device='cuda:0')
layer1.1.conv2.weight tensor(False, device='cuda:0')
layer1.1.bn2.weight tensor(False, device='cuda:0')
layer1.1.bn2.bias tensor(False, device='cuda:0')
layer2.0.conv1.weight tensor(False, device='cuda:0')
layer2.0.bn1.weight tensor(False, device='cuda:0')
layer2.0.bn1.bias tensor(False, device='cuda:0')
layer2.0.conv2.weight tensor(False, device='cuda:0')
layer2.0.bn2.weight tensor(False, device='cuda:0')
layer2.0.bn2.bias tensor(False, device='cuda:0')
layer2.0.downsample.0.weight tensor(False, device='cuda:0')
layer2.0.downsample.1.weight tensor(False, device='cuda:0')
layer2.0.downsample.1.bias tensor(False, device='cuda:0')
layer2.1.conv1.weight tensor(False, device='cuda:0')
layer2.1.bn1.weight tensor(False, device='cuda:0')
layer2.1.bn1.bias tensor(False, device='cuda:0')
layer2.1.conv2.weight tensor(False, device='cuda:0')
layer2.1.bn2.weight tensor(False, device='cuda:0')
layer2.1.bn2.bias tensor(False, device='cuda:0')
layer3.0.conv1.weight tensor(False, device='cuda:0')
layer3.0.bn1.weight tensor(False, device='cuda:0')
layer3.0.bn1.bias tensor(False, device='cuda:0')
layer3.0.conv2.weight tensor(False, device='cuda:0')
layer3.0.bn2.weight tensor(False, device='cuda:0')
layer3.0.bn2.bias tensor(False, device='cuda:0')
layer3.0.downsample.0.weight tensor(False, device='cuda:0')
layer3.0.downsample.1.weight tensor(False, device='cuda:0')
layer3.0.downsample.1.bias tensor(False, device='cuda:0')
layer3.1.conv1.weight tensor(False, device='cuda:0')
layer3.1.bn1.weight tensor(False, device='cuda:0')
layer3.1.bn1.bias tensor(False, device='cuda:0')
layer3.1.conv2.weight tensor(False, device='cuda:0')
layer3.1.bn2.weight tensor(False, device='cuda:0')
layer3.1.bn2.bias tensor(False, device='cuda:0')
layer4.0.conv1.weight tensor(False, device='cuda:0')
layer4.0.bn1.weight tensor(False, device='cuda:0')
layer4.0.bn1.bias tensor(False, device='cuda:0')
layer4.0.conv2.weight tensor(False, device='cuda:0')
layer4.0.bn2.weight tensor(False, device='cuda:0')
layer4.0.bn2.bias tensor(False, device='cuda:0')
layer4.0.downsample.0.weight tensor(False, device='cuda:0')
layer4.0.downsample.1.weight tensor(False, device='cuda:0')
layer4.0.downsample.1.bias tensor(False, device='cuda:0')
layer4.1.conv1.weight tensor(False, device='cuda:0')
layer4.1.bn1.weight tensor(False, device='cuda:0')
layer4.1.bn1.bias tensor(False, device='cuda:0')
layer4.1.conv2.weight tensor(False, device='cuda:0')
layer4.1.bn2.weight tensor(False, device='cuda:0')
layer4.1.bn2.bias tensor(False, device='cuda:0')
layer3_custom_block0.fc1.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc2.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.bias tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.bias tensor(False, device='cuda:0')
layer3_custom_block1.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.bias tensor(True, device='cuda:0')

P.S. The mean/std of new_vt is also calculated, which was omitted in the first post. I’ve also added a gradient hook to new_vt, but the nan gradients in the fc layers still exist.

def exchange(self, vs, vt):
    # vs and vt are of the same size NxCxHxW
    vs_mean = torch.mean(vs, dim=(2, 3))
    vs_std = torch.std(vs, dim=(2, 3)) + self.eps
    raw_es_domain = torch.cat((vs_mean, vs_std), dim=1)
    es_domain = self.relu(self.fc1(raw_es_domain))
    es_domain = self.relu(self.fc2(es_domain))

    weight = self.fc_w(es_domain)[:, :, None, None]  # size NxCx1x1, broadcasts against vt
    bias = self.fc_b(es_domain)[:, :, None, None]    # size NxCx1x1

    vt = nn.InstanceNorm2d(vt.size(1), affine=False)(vt)
    new_vt = weight * vt + bias

    new_vt_mean = torch.mean(new_vt, dim=(2, 3))
    new_vt_std = torch.std(new_vt, dim=(2, 3)) + self.eps
    raw_new_et_domain = torch.cat((new_vt_mean, new_vt_std), dim=1)
    new_et_domain = self.relu(self.fc1(raw_new_et_domain))
    new_et_domain = self.relu(self.fc2(new_et_domain))

    return vt, es_domain, new_et_domain

The funny thing is, after I add a hook to monitor the gradients of weight and bias in custom_block1, the nan in fc_w of custom_block1 just disappears, but the nan in fc_b does not!

def hook(grad):
    if torch.any(torch.isnan(grad)):
        print('fixing nan gradient')
    # replace nan entries in the incoming gradient with 0
    return torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)

weight.register_hook(hook)
bias.register_hook(hook)

####### after the above hooks are registered, check the gradients of the custom blocks
layer3_custom_block0.fc1.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc2.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_w.bias tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.weight tensor(False, device='cuda:0')
layer3_custom_block0.fc_b.bias tensor(False, device='cuda:0')
layer3_custom_block1.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.weight tensor(False, device='cuda:0')
layer3_custom_block1.fc_w.bias tensor(False, device='cuda:0')
layer3_custom_block1.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.bias tensor(True, device='cuda:0')

I know it may be a silly question, but how should I debug this?

Well, I don’t know whether the hook keeps the anomaly mode from complaining, and while I don’t fully understand what you are doing, it would seem that std is still the critical bit.
One (slow, but maybe OK for debugging) alternative to the gradient hook is to insert some noise, new_vt = new_vt + 1e-6 * torch.randn_like(new_vt), in the forward.
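In the context of the exchange() code above, that would be something along these lines (a sketch; the 1e-6 scale is just the value from the suggestion, and doing the same to vs before its mean/std also avoids the constant-slice case):

# insert tiny noise so no CxHxW slice is exactly constant
vs = vs + 1e-6 * torch.randn_like(vs)               # before vs_mean / vs_std
new_vt = new_vt + 1e-6 * torch.randn_like(new_vt)   # before new_vt_mean / new_vt_std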

Best regards

Thomas

Yes, it works!
After inserting some noise into both vs and new_vt as you suggested (without any gradient hook), the code now runs normally without giving nan values.

What I’m trying to do is to modify the moments of the target-domain features (vt) using the moments computed from the source-domain features (vs), and to calculate a loss based on the modified feature new_vt and its moments new_et_domain. In practice, there are multiple domains and each domain has its corresponding custom block. That is to say, custom_block{i}.exchange(v_{i}, v_{j}) is performed for each pair of domains (for example, with 3 domains in total).
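To make that concrete, the pairwise part looks roughly like this (just a sketch with illustrative names, not my actual training loop; feats[i] is the feature map of domain i at this stage and blocks[i] is custom_block{i}):

for i, block in enumerate(blocks):
    for j, feat_j in enumerate(feats):
        if j == i:
            continue
        vt_out, es_domain, new_et_domain = block.exchange(feats[i], feat_j)
        # the loss described above is then computed from the modified feature
        # and its moments (new_et_domain)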

However, there is one strange thing I don’t understand. If inserting noise works, then the nan should come from the std part. But then why doesn’t the gradient hook work as well? I tried adding noise only to vs, and it gives this error instead:

/opt/conda/conda-bld/pytorch_1573049306803/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
 new_vt = weight * vt + bias

RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.

Also, if I add noise to vs and a gradient hook to new_vt, the gradient of fc_b.bias becomes normal:

for name, param in self.model.backbone.named_parameters():
    print(name, torch.isnan(param.grad).any())

###### only noise to vs
layer3_custom_block1.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.bias tensor(True, device='cuda:0')

###### both noise to vs and gradient hook to new_vt
layer3_custom_block1.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block1.fc_b.bias tensor(False, device='cuda:0')
layer3_custom_block2.fc1.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc2.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_w.bias tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.weight tensor(True, device='cuda:0')
layer3_custom_block2.fc_b.bias tensor(False, device='cuda:0')

So the nan gradient of new_vt comes from the multiplication instead of from std? The only reason I can see for a multiplication to give a nan gradient is that its input is nan as well. But that doesn’t make sense, since inserting noise into new_vt solves the problem. Is there another possible reason for a multiplication to give a nan gradient?

(P.S. I found this issue: https://github.com/pytorch/pytorch/issues/15288, but I’m not sure if it’s the same case.)

So if you run anomaly mode, the error might be thrown between the NaN being produced and the hook being applied. (I must admit I don’t know, and in lieu of a simple snippet I can copy-paste into PyTorch, I’m not going to start experiments, sorry.)
One way around this might be to define your own autograd.Function for std instead of using the built-in one, so that it does not return NaNs in the first place.

Yes, I agree.

I must say, returning a nan gradient when computing std of an all-constant tensor is just so strange! The natural way is to return an all-0 gradient, isn’t it? Maybe PyTorch should fix it.

For others who have encountered the same problem, I wrote my own CustomStd like this:

import functools

class CustomStd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim, eps=1e-5, unbiased=True, keepdim=False):
        # deviations from the mean over the reduced dimensions
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.save_for_backward(input)
        ctx.eps = eps
        ctx.dev = dev
        ctx.numdim = input.dim()
        # number of elements in the reduced dimensions (minus 1 if unbiased)
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        ctx.std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        return ctx.std if keepdim else ctx.std.squeeze()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output
        # restore the reduced dimensions so grad_output broadcasts against dev
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        # the eps in the denominator keeps the gradient finite (all zeros) for a constant input
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel ** 2) / (ctx.std + ctx.eps) * grad_input
        return grad_input, None, None, None, None

It can pass a test like this:

input = torch.randn(16, 256, 14, 14, dtype=torch.float, requires_grad=True)
std1 = input.std(dim=(2, 3))
std2 = CustomStd.apply(input, (2, 3))
torch.sum(std1).backward(retain_graph=True)
grad1 = input.grad
input.grad = None
torch.sum(std2).backward()
grad2 = input.grad
############### TEST
torch.allclose(std1, std2)               # True
torch.allclose(grad1, grad2, atol=1e-3)  # True

In the second (gradient) test, I have to set atol=1e-3 to make it return True, so I guess there is a subtle difference between the gradient of CustomStd and that of torch.std.
Also, CustomStd now returns a 0 gradient for an all-constant tensor.
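For example, a quick check (a sketch with arbitrary shapes, reusing the CustomStd class above):

# an all-constant input: torch.std gives nan gradients here (see the example above),
# while CustomStd gives an all-zero gradient thanks to the eps in its backward
a = torch.zeros(2, 4, 8, 8, requires_grad=True)
CustomStd.apply(a, (2, 3)).sum().backward()
print(torch.isnan(a.grad).any())  # tensor(False)
print(a.grad.abs().sum())         # tensor(0.)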

I don’t think it’s natural to want the gradient of std at constant vectors, but yes, you could argue that picking 0 as a subgradient is better than NaN (which is likely artificially there from writing it as var(…)**0.5).