Anomaly detection: returned nan values in its 0th output, but everything seems fine?

My code keeps crashing after a couple of thousand iterations (suddenly all the weights go to nan), but nothing obvious seems to trigger it. I turned on anomaly detection, and now I get the following error already in the first iteration, but I can't really see the problem.

[W ..\torch\csrc\autograd\python_anomaly_mode.cpp:60] Warning: Error detected in SqrtBackward. Traceback of forward call that caused the error:
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\pydevd.py", line 2131, in <module>
    main()
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\pydevd.py", line 2122, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\pydevd.py", line 1431, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\pydevd.py", line 1438, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/Tue/PycharmProjects/Pfold/run_1d_supervised.py", line 109, in <module>
    losses = main()
  File "C:\Users\Tue\PycharmProjects\Pfold\supervised\main.py", line 71, in main
    net = train(net, optimizer, dl_train, loss_fnc, dl_test=dl_test, scheduler=lr_scheduler,ite=ite_start, loss_reg_fnc=loss_reg_fnc)
  File "C:\Users\Tue\PycharmProjects\Pfold\supervised\optimization.py", line 72, in train
    dists_pred, coords_pred = net(features,mask)
  File "C:\Users\Tue\PycharmProjects\Pfold\venv\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Tue\PycharmProjects\Pfold\supervised\network_vnet.py", line 162, in forward
    dists += (tr2DistSmall(x[:,i*3:(i+1)*3,:]),)
  File "C:\Users\Tue\PycharmProjects\Pfold\supervised\network_transformer.py", line 159, in tr2DistSmall
    D = torch.sqrt(D)
 (function print_stack)
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\pydevd.py", line 1438, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2020.1\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/Tue/PycharmProjects/Pfold/run_1d_supervised.py", line 109, in <module>
    losses = main()
  File "C:\Users\Tue\PycharmProjects\Pfold\supervised\main.py", line 71, in main
    net = train(net, optimizer, dl_train, loss_fnc, dl_test=dl_test, scheduler=lr_scheduler,ite=ite_start, loss_reg_fnc=loss_reg_fnc)
  File "C:\Users\Tue\PycharmProjects\Pfold\supervised\optimization.py", line 91, in train
    loss.backward()
  File "C:\Users\Tue\PycharmProjects\Pfold\venv\lib\site-packages\torch\tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Users\Tue\PycharmProjects\Pfold\venv\lib\site-packages\torch\autograd\__init__.py", line 125, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.

The function it crashes in is the following:

def tr2DistSmall(Y):
    # Y: (batch, dims, n) coordinates; returns the (batch, n, n) pairwise distance matrix
    k = Y.shape[1]
    # Center the coordinates
    Z = Y - torch.mean(Y, dim=2, keepdim=True)
    # Squared pairwise distances: ||z_i||^2 + ||z_j||^2 - 2 z_i . z_j
    D = torch.sum(Z**2, dim=1).unsqueeze(1) + torch.sum(Z**2, dim=1).unsqueeze(2) - 2*Z.transpose(1,2) @ Z
    D = 3*D/k
    # Force the diagonal to exactly zero and clip tiny negatives from rounding before the sqrt
    D[...,torch.arange(D.shape[-1]),torch.arange(D.shape[-1])] = 0
    D = torch.relu(D)
    D = torch.sqrt(D)
    return D

I have checked it in the forward pass, and everything looks fine: there are no nan values at any point, and all values seem reasonable. So can anyone illuminate what I'm doing wrong?

Edit:
I tried making the function it crashes in simpler, but even with the following I still get the problem:

def tr2Dist_new(r):
    d = torch.sum(r ** 2, dim=1).unsqueeze(1) + torch.sum(r ** 2, dim=1).unsqueeze(2) - 2 * r.transpose(1,2) @ r
    d[..., torch.arange(d.shape[-1]), torch.arange(d.shape[-1])] = 0
    d = torch.sqrt(d)
    return d

It seems like the issue here is that some of the values are zero, and for some reason anomaly detection does not like taking sqrt(0). I don't really understand why that would be a problem, though.
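
For reference, here is a stripped-down snippet (not my real training code, and the masked sum at the end is just my guess at what my loss effectively does) that should reproduce the same error:

import torch

torch.autograd.set_detect_anomaly(True)

r = torch.randn(2, 3, 10, requires_grad=True)
# same squared-distance construction as above, with the diagonal forced to zero
d = torch.sum(r ** 2, dim=1).unsqueeze(1) + torch.sum(r ** 2, dim=1).unsqueeze(2) - 2 * r.transpose(1, 2) @ r
d[..., torch.arange(d.shape[-1]), torch.arange(d.shape[-1])] = 0
d = torch.sqrt(d)

# sum only the off-diagonal entries, as a stand-in for my masked loss
mask = 1 - torch.eye(10)
(d * mask).sum().backward()
# RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.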

Hi,

The gradient of sqrt at 0 is going to be +inf (the derivative is 1/(2*sqrt(x)), which blows up as x approaches 0), so it is a fairly dangerous thing to have!
It might be better to compute the sqrt first and then set to 0 the values you want to zero out. That way, all the gradients will remain well defined.

The problem is that a few of the values on the diagonal are negative due to floating point precision errors, which is why I set them to zero before doing the sqrt operation.

But I guess I could do another masking and set all those to zero again afterwards if you think that will help with stability?

Maybe the simplest thing for you would be to set these to a small epsilon (like 1e-5).
They will have a large gradient, but it will be zeroed out by the relu/threshold, since 0 * large_value = 0, unlike 0 * +inf = nan.
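
A tiny standalone example of the difference (just a sketch, with an arbitrary eps):

import torch

torch.autograd.set_detect_anomaly(True)
eps = 1e-5

# Exact zero before the sqrt: the backward computes 0 / (2 * sqrt(0)) = nan
x = torch.zeros(3, requires_grad=True)
try:
    (0 * torch.sqrt(x)).sum().backward()
except RuntimeError as err:
    print(err)  # Function 'SqrtBackward' returned nan values in its 0th output.

# Small epsilon instead: sqrt's gradient is large but finite, and 0 * large = 0
y = torch.full((3,), eps, requires_grad=True)
(0 * torch.sqrt(y)).sum().backward()
print(y.grad)  # tensor([0., 0., 0.])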

I found a way around it:

def tr2Dist_new(r):
    nb, _, n = r.shape
    # Start from an all-zero distance matrix and only write the "safe" entries
    d = torch.zeros((nb, n, n), dtype=r.dtype, device=r.device)
    tmp = torch.sum(r ** 2, dim=1).unsqueeze(1) + torch.sum(r ** 2, dim=1).unsqueeze(2) - 2 * r.transpose(1, 2) @ r
    # Mask out the diagonal ...
    m = torch.ones_like(d, dtype=torch.bool, device=r.device)
    m[..., torch.arange(m.shape[-1]), torch.arange(m.shape[-1])] = False
    # ... and any non-positive values, so sqrt never sees a zero or negative entry
    m2 = tmp > 0
    m3 = m * m2
    tmp2 = tmp[m3]
    d[m3] = torch.sqrt(tmp2)
    return d

But it definitely makes the code more complicated, and I guess a bit slower.
I found a similar problem further on in my loss function, where, as far as I can see, I have to loop over the individual samples in a batch to get around it:

if safeversion:
    # Per-sample loop: gather only the entries selected by the mask before taking norms
    result = 0
    for i in range(nb):
        inputi = input[i, ...]
        targeti = target[i, ...]
        maski = mask[i, ...]
        result += torch.sum(torch.norm(inputi[maski] - targeti[maski]) ** 2 / torch.norm(targeti[maski]) ** 2)
else:
    # Vectorized version: zero out the masked entries and take per-sample norms over dims (1, 2)
    result = torch.sum(torch.norm(input * mask - target * mask, dim=(1, 2)) ** 2 / torch.norm(target * mask, dim=(1, 2)) ** 2)

It works, but it is convoluted and quite a bit slower. I’m really puzzled that other people don’t seem to have trouble with this?
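
I suppose the vectorized branch could also be kept by writing the ratio with squared sums directly, so that no norm/sqrt appears in the graph at all. An untested sketch, with an arbitrary clamp value guarding the denominator:

# Untested sketch: same per-sample ratio as the vectorized branch, but built
# from squared sums so the graph never contains a sqrt
diff2 = torch.sum((input * mask - target * mask) ** 2, dim=(1, 2))
ref2 = torch.sum((target * mask) ** 2, dim=(1, 2)).clamp(min=1e-12)
result = torch.sum(diff2 / ref2)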

Thanks for the help albanD

I would expect that replacing the zeroing/relu in your original code with D.threshold(eps, eps) would be simpler than what you did here.
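
Something like this, roughly (a sketch using the functional form torch.nn.functional.threshold; the eps value is arbitrary):

import torch
import torch.nn.functional as F

def tr2DistSmall(Y, eps=1e-5):
    k = Y.shape[1]
    Z = Y - torch.mean(Y, dim=2, keepdim=True)
    D = torch.sum(Z**2, dim=1).unsqueeze(1) + torch.sum(Z**2, dim=1).unsqueeze(2) - 2*Z.transpose(1, 2) @ Z
    D = 3*D/k
    # Everything <= eps (the zero / slightly negative diagonal included) becomes eps,
    # so sqrt has a finite gradient there; threshold's backward then zeroes the
    # gradient for the replaced entries
    D = F.threshold(D, eps, eps)
    D = torch.sqrt(D)
    return D

If you need the diagonal of the output to be exactly zero, you can still set it to zero after the sqrt, as mentioned above.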

I’m really puzzled that other people don’t seem to have trouble with this?

It is fairly rare to actually evaluate functions at a point where they are not differentiable (like sqrt at 0), so it does not come up very often.

Fair enough, I guess it is because I have been working on protein folding for so long that I consider a matrix with a diagonal of zeros to be completely normal.

And yeah I will experiment with your suggested solution as well, and see what works best in each case.
Thanks!

I think it is normal to have that; it's just that taking the sqrt of it is not as common.
In general in ML, people try to avoid non-differentiable points as much as possible, as they make training much less stable :confused: