Continue backward from tensor.grad

Hi, I'm using a third-party codebase where the authors define an input x as an nn.Parameter and compute its gradient themselves (without using loss.backward(), for sparsity reasons). The code looks like:

x = nn.Parameter(..., requires_grad=True)
# ... custom ops on x ...
# no loss.backward() here; sparse_grad() computes and stores x.grad itself
sparse_grad()
optimizer = torch.optim.RMSprop([x])
optimizer.step()
optimizer.zero_grad()
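
For reference, here is a minimal self-contained sketch of this pattern, assuming (hypothetically) that the third-party sparse_grad routine simply writes its result into x.grad so the optimizer can consume it; the gradient rule below is just a placeholder, not the real one:

import torch
import torch.nn as nn

x = nn.Parameter(torch.randn(1000, 1))
optimizer = torch.optim.RMSprop([x], lr=1e-2)

# stand-in for the third-party sparse_grad(): it computes a gradient for x
# by some custom rule and writes it into x.grad
with torch.no_grad():
    x.grad = torch.sign(x).clone()  # placeholder gradient, not the real rule

optimizer.step()       # consumes whatever is stored in x.grad
optimizer.zero_grad()  # clears x.grad for the next iteration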

I checked x after the sparse_grad function, and x.grad indeed contains the correct values. This should be fine, since we are replacing torch's backward with a custom backward function. Now I want to continue the backward computation from the gradients of x to some other input tensor. Is that possible in PyTorch?

Many thanks in advance!

You could call backward() on intermediate tensors, but since x is created as an nn.Parameter I don’t know which “other input tensor” would be connected to it. Could you describe your use case a bit more?

Hi @ptrblck, here x is a 2D tensor, say of shape 1000x1. I thought we can only call backward() on a 1x1 (scalar) tensor, right?

My use case looks somewhat like this:

input = cv2.imread(...)  # loads a numpy array; would need conversion to a float tensor
fc = nn.Linear(...)
x = fc(input)

The package then calculates x's grad for me, but I would like to optimize the parameters of the fc layer as well. The best I can currently think of is something like:

loss = x*x.grad
loss = loss.sum()
loss.backward()

But I’m not sure this is the correct way.

The tensor.backward() operation can be called without passing an explicit gradient argument only if the tensor is a scalar. If the tensor has more than a single element, you have to pass the gradient to backward(). Alternatively, you could use loss = x.sum() and call loss.backward() to compute the gradients in fc, if x.sum() is the loss you actually want.
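
For illustration, here is a small sketch (not your original code, with made-up layer sizes) showing both options: passing an explicit gradient to backward() on the non-scalar x, and reducing x to a scalar loss first:

import torch
import torch.nn as nn

fc = nn.Linear(3, 4)
inp = torch.randn(2, 3)
x = fc(inp)                         # non-scalar output, shape [2, 4]

# option 1: pass an explicit (e.g. externally computed) gradient of the loss w.r.t. x
external_grad = torch.ones_like(x)  # placeholder for the package's gradient
x.backward(external_grad)
print(fc.weight.grad)

# option 2: reduce to a scalar loss and let autograd handle the rest
fc.zero_grad()
x = fc(inp)
loss = x.sum()
loss.backward()
print(fc.weight.grad)               # same gradient as option 1, since d(sum)/dx is all ones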

Many thanks, @ptrblck! I never thought we could call backward this way.
Just to make sure:

x = nn.Parameter(...)
y = x**2

Suppose I manage to calculate y's grad with another package and want to update x.
Then I can write the following, correct?

y.backward(y.grad)

Yes, this should work as seen here:

import torch
import torch.nn as nn

x = nn.Parameter(torch.randn(2, 2))
y = x**2

# pretend an external package computed dloss/dy = 2 everywhere
y.backward(torch.ones_like(y) * 2)
print(x.grad)     # gradient computed by autograd
print(2 * x * 2)  # analytic check: dy/dx = 2x, scaled by the incoming gradient of 2
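
And to tie it back to your original use case, a hedged end-to-end sketch, assuming the external package simply hands you a gradient tensor (called grad_from_package here) with the same shape as x: you can feed it into x.backward() and then step an optimizer over fc's parameters.

import torch
import torch.nn as nn

fc = nn.Linear(10, 1)
optimizer = torch.optim.RMSprop(fc.parameters(), lr=1e-3)

inp = torch.randn(5, 10)                 # stand-in for the preprocessed image
x = fc(inp)

# grad_from_package would be whatever the third-party code computes as dloss/dx
grad_from_package = torch.randn_like(x)  # placeholder values

x.backward(grad_from_package)            # continues backprop into fc's parameters
optimizer.step()
optimizer.zero_grad()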