Second-order gradient zeroing on a different-shape Tensor

Hi, I’m trying to create a graph-based model to learn on unstructured data using torch and torch_geometric, in which my loss function will depend on first- and second-order derivatives.

Within the model I use my points’ 3D coordinates to compute edge weights from the distances between them. The problem I’m having is that I need to compute a second-order gradient w.r.t. the coordinates: I manage to obtain the first-order gradient but not the second one. Here is a minimal example to reproduce the issue:

import torch

# coordinates
x = torch.tensor([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0]], requires_grad=True)

# Graph edge connectivity
edges = torch.tensor(
    [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4],
     [1, 2, 3, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 1, 2, 3]])

# Compute weight
row, col = edges
dist_sq = torch.sum((x[col] - x[row]).pow(2), dim=1)
weight = dist_sq.pow(-1)

# Compute 1st order derivative: This will have some non-null outputs
ones = torch.ones_like(weight)
qx = torch.autograd.grad(weight, x, grad_outputs=ones, create_graph=True)[0]
print(qx)

# Compute 2nd order derivative: the gradient will be null.
# A UserWarning about accessing the .grad attribute of a non-leaf tensor
# may also be printed; I tried forcing retain_grad() but nothing changed.
ones = torch.ones_like(qx)
qxx = torch.autograd.grad(qx, x, grad_outputs=ones, create_graph=True)[0]
print(qxx)

If, instead of computing the weight Tensor as an “edge attribute”, I compute a “nodal” attribute by doing, say, weight = x.pow(2), I get the proper gradients… but that is not what I’m looking for. So far my only guess is that the shape change (from nodal to edge) might be causing a problem in the gradient computation.
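For reference, here is a minimal sketch of the nodal variant I mean (same coordinates, only the weight computation changes); in this case both gradient calls return non-zero values:

import torch

# coordinates
x = torch.tensor([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0]], requires_grad=True)

# "Nodal" attribute: each entry depends only on its own node's coordinates
weight = x.pow(2)

ones = torch.ones_like(weight)
qx = torch.autograd.grad(weight, x, grad_outputs=ones, create_graph=True)[0]
print(qx)   # 2 * x

ones = torch.ones_like(qx)
qxx = torch.autograd.grad(qx, x, grad_outputs=ones, create_graph=True)[0]
print(qxx)  # all entries equal to 2, i.e. non-zero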

Any ideas?

Thanks in advance!

Hi,

I would say that most likely the gradients cancel out and are actually 0.
In particular, doing x[col] - x[row] might be cancelling out all the gradients, no?

I was expecting to avoid the cancelling problem by using the square of the distances, and in the end to compute the trace of the Hessian matrix at each node, but I might be missing something.

Also, if the gradients were to cancel, wouldn’t I also get zero for the first gradient computation? I’m obtaining:
[[ 4., 4., 4.],
[-6., 2., 2.],
[ 2., -6., 2.],
[ 2., 2., -6.],
[-2., -2., -2.]]

Could you write down on paper what the gradient and second-order gradient would be for this case (with a small x and fewer edges)?
Does that match what you expect?
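For example, a sketch with just 2 nodes in 1D and the symmetric edges (0, 1) and (1, 0) (not your exact code, but the same structure):

import torch

x = torch.tensor([[0.0], [1.0]], requires_grad=True)
row = torch.tensor([0, 1])
col = torch.tensor([1, 0])

# Both edges see the same distance, so sum(weight) = 2 / (x1 - x0)^2
weight = torch.sum((x[col] - x[row]).pow(2), dim=1).pow(-1)

# By hand: d/dx0 = 4 / (x1 - x0)^3 and d/dx1 = -4 / (x1 - x0)^3
qx = torch.autograd.grad(weight, x, grad_outputs=torch.ones_like(weight),
                         create_graph=True)[0]
print(qx)   # [[4.], [-4.]]

# The two components are equal and opposite for every x, so their sum
# (which is what the second grad call differentiates) is identically 0.
qxx = torch.autograd.grad(qx, x, grad_outputs=torch.ones_like(qx),
                          create_graph=True)[0]
print(qxx)  # [[0.], [0.]]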

Note that the function passes gradcheck, so finite differences agree with PyTorch that the gradient should indeed be 0 here:

import torch

# coordinates
x = torch.tensor([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0]], dtype=torch.double, requires_grad=True)

def fn(x):
    # Graph edge connectivity
    edges = torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4],
         [1, 2, 3, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 1, 2, 3]])

    # Compute weight
    row, col = edges
    dist_sq = torch.sum((x[col] - x[row]).pow(2), dim=1)
    weight = dist_sq.pow(-1)

    # Compute 1st order derivative: This will have some non-null outputs
    ones = torch.ones_like(weight)
    qx = torch.autograd.grad(weight, x, grad_outputs=ones, create_graph=True)[0]
    return qx

torch.autograd.gradcheck(fn, x) # Ensure that grad(fn(x), x) is correct.

Hi @albanD, thanks a lot! So yes, I was forgetting a big detail: since my function depends simultaneously on x[col] and x[row], every edge (i, j) contributes equal and opposite terms to the first derivatives at x[i] and x[j]. Summing all the entries of qx (which is what grad_outputs=ones does in the second call) therefore gives something identically zero, so the second derivative will always be zero…
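A quick sanity check of this on the minimal example from my first post (a sketch):

import torch

x = torch.tensor([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0]], requires_grad=True)
row = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4])
col = torch.tensor([1, 2, 3, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 1, 2, 3])

weight = torch.sum((x[col] - x[row]).pow(2), dim=1).pow(-1)
qx = torch.autograd.grad(weight, x, grad_outputs=torch.ones_like(weight),
                         create_graph=True)[0]

# weight depends only on differences x[col] - x[row], so every edge adds
# equal and opposite terms to the gradients of its two endpoints:
print(qx.sum(dim=0))  # tensor([0., 0., 0.], ...)
# grad_outputs=ones in the second call differentiates exactly this sum,
# which is identically zero, hence qxx == 0.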

Since what I want is a function that treats the x[i] coordinate as the local variable, the simplest solution I found was to copy x into a detached variable:

A simpler 1D test with 3 nodes and 2 edges:

x = torch.tensor([[0.0], [1.0], [2.0]], requires_grad=True)
y = x.clone().detach()  # detached copy: neighbour coordinates are treated as constants

edges = torch.tensor([[0, 1], [1, 2]])
row, col = edges

dist_sq = torch.sum((y[col] - x[row]).pow(2), dim=1)
weight = dist_sq.pow(-1)

# Compute first gradient 
ones = torch.ones_like(weight)
qx = torch.autograd.grad(weight, x, grad_outputs=ones, create_graph=True, retain_graph=True)[0]
print(qx)

# Compute second gradient 
ones = torch.ones_like(qx)
qxx = torch.autograd.grad(qx, x, grad_outputs=ones, create_graph=True, retain_graph=True)[0]
print(qxx)

Now I manage to obtain my second derivative!
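For completeness, the same detach trick applied to the original 3D example (a sketch, using the same tensors as in my first post):

import torch

x = torch.tensor([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 1.0]], requires_grad=True)
y = x.clone().detach()  # neighbour coordinates treated as constants

edges = torch.tensor(
    [[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4],
     [1, 2, 3, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 1, 2, 3]])
row, col = edges

# Only x[row] carries gradient information; y[col] is a constant copy
dist_sq = torch.sum((y[col] - x[row]).pow(2), dim=1)
weight = dist_sq.pow(-1)

ones = torch.ones_like(weight)
qx = torch.autograd.grad(weight, x, grad_outputs=ones, create_graph=True)[0]

ones = torch.ones_like(qx)
qxx = torch.autograd.grad(qx, x, grad_outputs=ones, create_graph=True)[0]
print(qxx)  # non-zero now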

Thanks again!

P.S.: I still get the UserWarning though (in debug mode):
“UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won’t be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor …”

It seems to be associated with the creation of the dist_sq tensor. I tried calling dist_sq.retain_grad(), but nothing changed… Since I do manage to compute the gradients, I’m not sure whether this could cause an error or a miscalculation during backpropagation when training?

Perfect!

P.S.: I still get the UserWarning though (in debug mode)

This is weird. I don’t see any warning when running your code on Colab.
Can you run your code with Python’s -W error argument to turn the warning into an error, so that you get a full stack trace of where it happens?
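Alternatively, a small sketch to do the same from within the script (handy if you are inside a debugger and can’t easily pass interpreter flags):

import warnings

# Turn that specific UserWarning into an exception so execution stops
# with a full stack trace at the line that accesses .grad on a non-leaf tensor.
warnings.filterwarnings("error", message=r"The \.grad attribute of a Tensor")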