Autograd problem when designing a custom loss

Hello everybody,

I am having a hard time designing a loss function that applies a Sobel filter to the batches before computing the MSE. I am quite sure the problem is related to the autograd computational graph being detached somewhere, but I just cannot solve it.

Here is my code. Can anyone see what I am missing?

import torch
import torch.nn as nn

def sobel_MSE(output, target):
    # Sobel kernels for gradients along x and y
    dx = (torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]], requires_grad=True)).float()
    dy = (torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]], requires_grad=True)).float()

    dx = dx.cuda()
    dy = dy.cuda()

    dx = dx.view((1, 1, 3, 3))
    dy = dy.view((1, 1, 3, 3))

    # gradients of the network output
    doutdx = nn.functional.conv2d(output, dx, padding=1)
    doutdy = nn.functional.conv2d(output, dy, padding=1)

    # gradients of the target
    dtardx = nn.functional.conv2d(target, dx, padding=1)
    dtardy = nn.functional.conv2d(target, dy, padding=1)

    # gradient magnitudes
    dout = torch.sqrt(torch.pow(doutdx, 2) + torch.pow(doutdy, 2))
    dtar = torch.sqrt(torch.pow(dtardx, 2) + torch.pow(dtardy, 2))

    # MSE between the two gradient magnitude maps
    out = torch.mean(torch.pow(dout - dtar, 2))

    return out
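
For context, the loss is called from a standard training loop, roughly like this (simplified sketch; model, optimizer and train_loader are just placeholders for my actual setup):

for inputs, targets in train_loader:
    inputs, targets = inputs.cuda(), targets.cuda()

    optimizer.zero_grad()
    output = model(inputs)              # output stays attached to the graph
    loss = sobel_MSE(output, targets)
    loss.backward()
    optimizer.step()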

Would you like to train dx and dy (since you’ve set their requires_grad attribute to True)?
If so, you shouldn’t overwrite them in these lines of code, but would need to use a new variable name and pass the original dx and dy to the optimizer:

dx = dx.cuda()
dy = dy.cuda()

dx = dx.view((1, 1, 3, 3))
dy = dy.view((1, 1, 3, 3))

On the other hand, if you don’t want to train these tensors, you could remove the requires_grad setting and as long as output is attached to the graph, the code should work.
Are you seeing any errors or unexpected behavior?
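
If you do want to train the kernels, a rough sketch of how that could look (model and the optimizer settings below are just placeholders, not from your code) is to keep dx and dy as nn.Parameters inside a small module, so they keep their identity and can be passed to the optimizer:

class SobelMSE(nn.Module):
    def __init__(self):
        super().__init__()
        # trainable kernels, initialized with the Sobel values
        self.dx = nn.Parameter(torch.tensor([[1.0, 0.0, -1.0],
                                             [2.0, 0.0, -2.0],
                                             [1.0, 0.0, -1.0]]).view(1, 1, 3, 3))
        self.dy = nn.Parameter(torch.tensor([[1.0, 2.0, 1.0],
                                             [0.0, 0.0, 0.0],
                                             [-1.0, -2.0, -1.0]]).view(1, 1, 3, 3))

    def forward(self, output, target):
        doutdx = nn.functional.conv2d(output, self.dx, padding=1)
        doutdy = nn.functional.conv2d(output, self.dy, padding=1)
        dtardx = nn.functional.conv2d(target, self.dx, padding=1)
        dtardy = nn.functional.conv2d(target, self.dy, padding=1)
        dout = torch.sqrt(doutdx ** 2 + doutdy ** 2)
        dtar = torch.sqrt(dtardx ** 2 + dtardy ** 2)
        return torch.mean((dout - dtar) ** 2)

criterion = SobelMSE().cuda()
optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()), lr=1e-3)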

Thank you @ptrblck for your reply.

I removed the requires_grad setting to make the Sobel filter static. The problem is that when I run the code, after the first loss.backward() the network output becomes NaN. I believe there is some problem when computing the gradients.
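
Here is roughly how I am trying to localize it, in case it helps (just a sketch; model, inputs and targets are placeholders for my actual code):

# make the backward pass report the operation that produced the NaN
torch.autograd.set_detect_anomaly(True)

output = model(inputs)
loss = sobel_MSE(output, targets)
loss.backward()

# check whether any parameter received a NaN gradient
for name, p in model.named_parameters():
    if p.grad is not None and torch.isnan(p.grad).any():
        print("NaN gradient in", name)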

If I change the loss to a plain MSE:

def my_mse(output, target):
    out = torch.mean(torch.pow(output - target, 2))
    return out

everything works fine, so I suppose the problem is not in the rest of the code.

Do you have any idea what I am missing?

Thanks a lot