Hi everyone!
Let’s assume that I have at hand:
- a function S(x,y), which is symmetric with respect to x and y
- an input tensor a

and that I want to compute the gradient of f(a) = S(a,a) with respect to a, plus higher-order derivatives.
A standard way of doing this in PyTorch is to type:
import torch
S = lambda x, y: (-x * y).exp()  # or any other symmetric function of (x,y)
a = torch.tensor([3.], requires_grad=True)
f = S(a, a)  # = exp(-a*a)
g = torch.autograd.grad(f, [a], create_graph=True)[0]
# = -a*exp(-a*a) + -a*exp(-a*a) -> inefficient. We should do better!
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*a) * (4*a*a - 2)
print(a, f, g, h)
print((-a * a).exp() * (4 * a * a - 2))  # analytic check for h
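Just to spell out what the comments above are asserting for this particular choice S(x,y) = exp(-x*y): autograd applies the multivariate chain rule and evaluates both partial derivatives separately, even though symmetry makes them equal,

\frac{df}{da}(a) \;=\; \frac{\partial S}{\partial x}(a,a) + \frac{\partial S}{\partial y}(a,a)
\;=\; -a\,e^{-a^2} + \big(-a\,e^{-a^2}\big) \;=\; -2a\,e^{-a^2},
\qquad
\frac{d^2 f}{da^2}(a) \;=\; e^{-a^2}\,(4a^2 - 2).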
However, in theory, I could do much better: since I know that S(x,y) is symmetric, I also know that dS/dx(a,a) = dS/dy(a,a), and I would thus like to compute df/da = dS(a,a)/da as df/da(a) = 2*dS/dx(a,a) instead of df/da(a) = dS/dx(a,a) + dS/dy(a,a), halving my compute time. Is this possible? Well, kind of. Using the standard .detach() method, etc., I was able to compute the right value for df/da, but could not get a fully differentiable expression:
a = torch.tensor([3.], requires_grad=True)
b = a.detach().requires_grad_()  # b has the same value as a, but is cut from a's graph
f = S(a, b)  # = exp(-a*b)
g = torch.autograd.grad(2 * f, [b], create_graph=True)[0]
# = -2*a*exp(-a*b) -> efficient and correct, since a.data = b.data...
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*b) * (2*a*b - 2) -> This is not what we want!
# The 4 was replaced by a 2!
print(a, f, g, h)
print((-a * b).exp() * (2 * a * b - 2))  # matches h, but not the true second derivative
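To make the failure explicit (writing down what I believe the detach trick actually computes): since b is cut from the graph of a, backpropagating g with respect to a only picks up the mixed second derivative, and the pure second derivative in x is lost,

g \;=\; 2\,\frac{\partial S}{\partial y}(a,b)\Big|_{b=a} \;=\; \frac{df}{da}(a)
\quad\text{(correct, by symmetry)},
\qquad
h \;=\; \frac{\partial g}{\partial a} \;=\; 2\,\frac{\partial^2 S}{\partial x\,\partial y}(a,a),

\text{whereas}\qquad
\frac{d^2 f}{da^2}(a)
\;=\; 2\,\frac{\partial^2 S}{\partial x^2}(a,a) \;+\; 2\,\frac{\partial^2 S}{\partial x\,\partial y}(a,a).

For S(x,y) = exp(-x*y), the missing term 2*d²S/dx²(a,a) = 2*a²*exp(-a²) is exactly the gap between exp(-a*b)*(2*a*b - 2) and exp(-a*a)*(4*a*a - 2): this is where the 4 turns into a 2.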
Did I miss something? Is there a workaround?
Using undocumented low-level methods (?), I was thinking of replacing the references to b in g.grad_fn by references to a… Is this possible?
This would be really useful to many researchers who are currently using PyTorch for shape analysis. Symmetric functions pop up all the time as high-dimensional kinetic energies, etc., and we would like to be able to backprop through their expressions without having to compute the same derivative twice.