Autograd problem when designing a custom loss

Hello everybody,

I am having a hard time designing a loss function that applies a Sobel filter to the batches before computing the MSE. I am quite sure that the problem is related to the autograd computational graph getting detached, but I just cannot solve it.

Here is my code. Can anyone see what I am missing?

import torch
import torch.nn as nn

def sobel_MSE(output, target):
    # Sobel kernels for horizontal (dx) and vertical (dy) intensity changes
    dx = (torch.tensor([[1.0, 0.0, -1.0],[2.0, 0.0, -2.0],[1.0, 0.0, -1.0]], requires_grad=True)).float()
    dy = (torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]], requires_grad=True)).float()

    dx = dx.cuda()
    dy = dy.cuda()

    # conv2d expects weights shaped (out_channels, in_channels, kH, kW); here 1 -> 1 channel
    dx = dx.view((1, 1, 3, 3))
    dy = dy.view((1, 1, 3, 3))

    # Sobel responses of the network output and of the target
    doutdx = nn.functional.conv2d(output, dx, padding=1)
    doutdy = nn.functional.conv2d(output, dy, padding=1)

    dtardx = nn.functional.conv2d(target, dx, padding=1)
    dtardy = nn.functional.conv2d(target, dy, padding=1)

    # Gradient magnitude of each image
    dout = torch.sqrt(torch.pow(doutdx, 2) + torch.pow(doutdy, 2))
    dtar = torch.sqrt(torch.pow(dtardx, 2) + torch.pow(dtardy, 2))

    # MSE between the two gradient-magnitude maps
    out = torch.mean(torch.pow(dout-dtar,2))

    return out

Would you like to train dx and dy (since you’ve set their requires_grad attribute to True)?
If so, you shouldn’t overwrite them in these lines of code, but would need to use a new variable name and pass the original dx and dy to the optimizer:

dx = dx.cuda()
dy = dy.cuda()

dx = dx.view((1, 1, 3, 3))
dy = dy.view((1, 1, 3, 3))
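A minimal sketch of that suggestion, assuming the kernels really should be learned: the leaf tensors are created once, directly on the target device (so no `.cuda()` call replaces them with non-leaf copies), reshaped under new names inside the loss, and the original leaves are passed to the optimizer. The names `kx`, `ky`, `sobel_mse_trainable`, the random stand-in tensors, and the SGD learning rate are illustrative assumptions, not part of the original code.

```python
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the trainable kernels directly on the device: calling .cuda() on a
# CPU leaf would return a non-leaf copy, which the optimizer cannot update.
dx = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]],
                  device=device, requires_grad=True)
dy = torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]],
                  device=device, requires_grad=True)

# The original leaves are handed to the optimizer (lr is an arbitrary choice).
optimizer = torch.optim.SGD([dx, dy], lr=1e-3)

def sobel_mse_trainable(output, target):
    # New names for the reshaped views, so dx and dy still refer to the leaves.
    kx = dx.view(1, 1, 3, 3)
    ky = dy.view(1, 1, 3, 3)
    dout = torch.sqrt(F.conv2d(output, kx, padding=1) ** 2 + F.conv2d(output, ky, padding=1) ** 2)
    dtar = torch.sqrt(F.conv2d(target, kx, padding=1) ** 2 + F.conv2d(target, ky, padding=1) ** 2)
    return torch.mean((dout - dtar) ** 2)

# Usage with random single-channel stand-ins for the network output and labels
torch.manual_seed(0)
output = torch.rand(2, 1, 8, 8, device=device)
target = torch.rand(2, 1, 8, 8, device=device)
loss = sobel_mse_trainable(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # updates dx and dy in place
```

Because `dx` and `dy` are never reassigned, `dx.grad` and `dy.grad` are populated after `backward()` and the optimizer can actually update them.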

On the other hand, if you don’t want to train these tensors, you could remove the requires_grad setting and, as long as output is attached to the graph, the code should work.
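A small sketch illustrating that point, using a random tensor as an assumed stand-in for the network output: a kernel created without `requires_grad` acts as a constant, yet the convolution still propagates gradients back to its input.

```python
import torch
import torch.nn.functional as F

# Static Sobel kernel: requires_grad defaults to False, so it is a constant.
kernel = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]).view(1, 1, 3, 3)

# Stand-in for the network output, which is attached to the graph.
output = torch.rand(1, 1, 5, 5, requires_grad=True)
edges = F.conv2d(output, kernel, padding=1)
edges.sum().backward()

print(kernel.grad)              # None: the kernel is not trained
print(output.grad is not None)  # True: gradients still reach the output
```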
Are you seeing any errors or unexpected behavior?

Thank you @ptrblck for your reply.

I removed the requires_grad setting to make the Sobel filter static. The problem is that, when I run the code, the network output becomes NaN after the first loss.backward(). I believe something goes wrong when the gradients are computed.

If I replace the loss with a plain MSE:

def my_mse(output, target):
    out = torch.mean(torch.pow(output-target,2))
    return out

everything works fine, so I suppose there is no problem in the rest of the code.
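For reference, this plain MSE is what `torch.nn.functional.mse_loss` computes with its default `reduction="mean"`. A quick check, using random tensors as assumed stand-ins for the output and labels:

```python
import torch
import torch.nn.functional as F

def my_mse(output, target):
    out = torch.mean(torch.pow(output - target, 2))
    return out

output = torch.rand(2, 1, 8, 8)
target = torch.rand(2, 1, 8, 8)

# F.mse_loss with its default reduction="mean" computes the same quantity
print(torch.allclose(my_mse(output, target), F.mse_loss(output, target)))  # True
```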

Do you have any idea about what I am missing?

Thanks a lot

Could you rerun the code with torch.autograd.set_detect_anomaly(True) at the beginning of the script and post the stack trace here?
Based on the description I assume the loss does not contain any invalid values?
If that’s the case, could you check all gradients after the first backward() pass for NaN values?
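A minimal sketch of both checks, assuming a tiny `nn.Conv2d` and a squared-output loss as stand-ins for the real network and the custom loss; `report_nan_grads` is a hypothetical helper, not a PyTorch function. With anomaly detection enabled, `backward()` raises an error when a backward function returns NaN and prints the forward call that created it.

```python
import torch
import torch.nn as nn

def report_nan_grads(model):
    """Hypothetical helper: names of parameters whose .grad contains NaN."""
    bad = []
    for name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            bad.append(name)
    return bad

# Record forward tracebacks so a NaN in backward() points to the op that caused it
torch.autograd.set_detect_anomaly(True)

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for the real network
inputs = torch.rand(2, 1, 8, 8)
loss = model(inputs).pow(2).mean()  # stand-in for the custom loss

print("loss is finite:", torch.isfinite(loss).item())
loss.backward()
print("parameters with NaN gradients:", report_nan_grads(model))
```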

Thanks again @ptrblck!

After enabling anomaly detection, I got the following error trace:

..\torch\csrc\autograd\python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
 File "C:/Users/Luis/PycharmProjects/unet-googlenet/train.py", line 291, in <module>
   train()
 File "C:/Users/Luis/PycharmProjects/unet-googlenet/train.py", line 214, in train
   loss = sobel_MSE(outputs, labels)
 File "C:/Users/Luis/PycharmProjects/unet-googlenet/train.py", line 142, in sobel_MSE
   dout = torch.sqrt(torch.pow(doutdx, 2) + torch.pow(doutdy, 2))

Traceback (most recent call last):
 File "C:/Users/Luis/PycharmProjects/unet-googlenet/train.py", line 291, in <module>
   train()
 File "C:/Users/Luis/PycharmProjects/unet-googlenet/train.py", line 216, in train
   loss.backward()
 File "C:\Users\Luis\Anaconda3\envs\unet-goolenet\lib\site-packages\torch\tensor.py", line 166, in backward
   torch.autograd.backward(self, gradient, retain_graph, create_graph)
 File "C:\Users\Luis\Anaconda3\envs\unet-goolenet\lib\site-packages\torch\autograd\__init__.py", line 99, in backward
   allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.

Does this mean that the gradient computed by torch.sqrt was NaN at the following line?

dout = torch.sqrt(torch.pow(doutdx, 2) + torch.pow(doutdy, 2))

Yes!

I replaced

dout = torch.sqrt(torch.pow(doutdx, 2) + torch.pow(doutdy, 2))

by

dout = (torch.pow(doutdx, 2) + torch.pow(doutdy, 2))

and now the backward pass seems to work fine! Removing the sqrt does not lose the edge-detection semantics, but now I am curious to understand why this is happening.

Do you have any idea?

torch.sqrt() gives NaN gradients for negative inputs, which shouldn’t be possible given your code snippet, and an Inf gradient for a zero input, which might happen in your use case (e.g. wherever both Sobel responses are zero):

import torch

x = torch.tensor([0.], requires_grad=True)
y = torch.sqrt(x)
y.backward()
print(x.grad)
# tensor([inf])
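The NaN reported by anomaly detection (rather than Inf) is consistent with the upstream gradient also being zero at those positions: where both Sobel magnitudes are zero, `dout - dtar` is zero, so the MSE passes a zero gradient into the sqrt, whose derivative 1/(2*sqrt(x)) then yields 0/0. Even with a nonzero upstream gradient, the Inf from the sqrt is multiplied by the pow derivative 2*x = 0 further back. A minimal reproduction of both cases, assuming scalar stand-ins for the Sobel responses:

```python
import torch

# Case 1: zero input with a zero upstream gradient, as where dout == dtar == 0
x = torch.zeros(1, requires_grad=True)
y = torch.sqrt(x)
y.backward(torch.zeros(1))
print(x.grad)  # tensor([nan]): 0 / (2 * sqrt(0)) = 0 / 0

# Case 2: a Sobel response that is exactly 0, fed through pow and then sqrt
g = torch.zeros(1, requires_grad=True)
mag = torch.sqrt(g ** 2)
mag.backward(torch.ones(1))
print(g.grad)  # tensor([nan]): sqrt gives inf, then pow multiplies it by 2 * g = 0
```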

You could add a small eps value inside the sqrt to avoid this. However, note that the gradients can still become large for small inputs, since the derivative of sqrt(x) is 1/(2*sqrt(x)), which grows without bound as x approaches zero.
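A minimal sketch of the eps suggestion, assuming single-channel inputs and an eps of 1e-6 (an arbitrary choice); `sobel_mse_eps` is a hypothetical name. The kernels are created without `requires_grad` on the same device and dtype as the output, and eps keeps the sqrt argument strictly positive, so its derivative is bounded by 1/(2*sqrt(eps)).

```python
import torch
import torch.nn.functional as F

def sobel_mse_eps(output, target, eps=1e-6):
    """Sobel-magnitude MSE with eps inside the sqrt (eps value is an assumption)."""
    kx = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]],
                      device=output.device, dtype=output.dtype).view(1, 1, 3, 3)
    ky = torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]],
                      device=output.device, dtype=output.dtype).view(1, 1, 3, 3)

    def magnitude(img):
        gx = F.conv2d(img, kx, padding=1)
        gy = F.conv2d(img, ky, padding=1)
        # eps keeps the sqrt argument strictly positive, so its derivative stays finite
        return torch.sqrt(gx ** 2 + gy ** 2 + eps)

    return torch.mean((magnitude(output) - magnitude(target)) ** 2)

# Flat (all-zero) images: every Sobel response is 0, the case that produced NaN
output = torch.zeros(1, 1, 8, 8, requires_grad=True)
target = torch.zeros(1, 1, 8, 8)
loss = sobel_mse_eps(output, target)
loss.backward()
print(torch.isfinite(output.grad).all().item())  # True
```

Squaring the magnitude instead, as done earlier in the thread, avoids the sqrt entirely; the eps variant keeps the loss close to the original magnitude-based definition.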
