Computing the gradient of a symmetric function: graph surgery needed?

Hi everyone!
Let’s assume that I have at hand:

  • a function S(x,y) which is symmetric with respect to x and y
  • an input tensor a

and that I want to compute the gradient of f(a) = S(a,a) with respect to a, as well as higher-order derivatives.
The standard way of doing this in PyTorch is:

import torch

S = lambda x, y: (-x * y).exp()  # or any other symmetric function of (x, y)
a = torch.tensor([3.], requires_grad=True)

f = S(a, a)  # = exp(-a*a)
g = torch.autograd.grad(f, [a], create_graph=True)[0]
# = dS/dx(a,a) + dS/dy(a,a) = -a*exp(-a*a) - a*exp(-a*a)
#   -> inefficient: the same term is computed twice. We should do better!
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*a) * (4*a*a - 2)
print(a, f, g, h)
print((-a * a).exp() * (4 * a * a - 2))  # closed-form check of h

However, in theory I could do much better: since S(x,y) = S(y,x), the two partial derivatives coincide on the diagonal, so I also know that

dS/dx(a,a) = dS/dy(a,a)

and I would thus like to compute df/da = dS(a,a)/da as

df/da(a) = 2*dS/dx(a,a)     instead of   df/da(a) = dS/dx(a,a) + dS/dy(a,a),

thereby halving the compute time. Is this possible? Well, kind of. Using the standard .detach() method, I was able to compute the right value for df/da, but could not get an expression whose own derivatives are correct:

a = torch.tensor([3.], requires_grad=True)
b = a.detach().requires_grad_()
f = S(a, b)  # = exp(-a*b)
g = torch.autograd.grad(2 * f, [b], create_graph=True)[0]
# = -2*a*exp(-a*b)            -> efficient, and correct since b == a
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*b) * (2*a*b - 2)   -> This is not what we want!
#                                The 4 was replaced by a 2!
print(a, f, g, h)
print((-a * b).exp() * (2 * a * b - 2))
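Expanding the second derivative by hand shows which term the detached version drops. In this sketch, ∂_x and ∂_y (notation introduced here, not from the original post) denote partial derivatives with respect to the first and second argument of S, and the example is the question's own S(x,y) = exp(-xy):

```latex
\begin{align*}
f'(a) &= \partial_x S(a,a) + \partial_y S(a,a) = 2\,\partial_y S(a,a) =: g(a)
  && \text{(symmetry: } \partial_x S(x,y) = \partial_y S(y,x)\text{)} \\
g'(a) &= 2\,\partial_x\partial_y S(a,a) + 2\,\partial_y^2 S(a,a)
  && \text{(paths through } a \text{ and through } b = a\text{)} \\
\intertext{For $S(x,y) = e^{-xy}$: $\partial_x\partial_y S = (xy - 1)\,e^{-xy}$ and $\partial_y^2 S = x^2 e^{-xy}$, so}
\text{detached } b:\quad 2\,\partial_x\partial_y S(a,b) &= e^{-ab}\,(2ab - 2) \\
\text{full:}\quad g'(a) &= e^{-a^2}(2a^2 - 2) + 2a^2 e^{-a^2} = e^{-a^2}\,(4a^2 - 2)
\end{align*}
```

Holding b constant keeps only the mixed term; the missing 2∂_y²S(a,a) = 2a²e^{-a²} is exactly the gap between the 2 and the 4.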

Did I miss something? Is there a workaround?
I was thinking of using (undocumented?) low-level methods to replace the references to b in g.grad_fn with references to a. Is this possible?

This would be really useful to the many researchers who currently use PyTorch for shape analysis. Symmetric functions pop up all the time, e.g. as high-dimensional kinetic energies, and we would like to backpropagate through their expressions without computing the same derivative twice.

You stopped the gradient at b, which should have backpropagated to a. If you don't use detach() and make b a view of a instead, it works:

import torch

a = torch.tensor([3.], requires_grad=True)
b = a.view_as(a).requires_grad_()  # view_as keeps b connected to a; detach() does not
S = lambda x, y: x * y
f = S(a, b)  # = a*b
g = torch.autograd.grad(2 * f, [b], create_graph=True)[0]  # = 2*a
h = torch.autograd.grad(g, [a], create_graph=True)[0]      # = 2, the correct d²(a*a)/da²
print(g, h)
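A self-contained version of the view_as approach, applied to the question's original S(x,y) = exp(-x*y) rather than x*y, may help confirm that it recovers the 4 in exp(-a*a)*(4*a*a - 2). This is a minimal sketch that assumes comparing against the closed forms from the question is a sufficient check; the names b_cut, g_ref, h_ref and h_cut_ref are illustrative, and float64 is used only to keep the allclose comparisons tight. It checks correctness only and does not measure whether requesting a single partial derivative is actually cheaper.

```python
import torch


def S(x, y):
    # Symmetric example from the question: S(x, y) = exp(-x*y) = S(y, x)
    return (-x * y).exp()


a = torch.tensor([3.0], dtype=torch.float64, requires_grad=True)

# b is a view of a: same value, still attached to a in the autograd graph.
# It already requires grad because a does, so no requires_grad_() is needed.
b = a.view_as(a)

f = S(a, b)
# First derivative via symmetry: request only dS/dy and double it.
g = torch.autograd.grad(2 * f, [b], create_graph=True)[0]
# Second derivative: autograd follows g back to a directly AND through b.
h = torch.autograd.grad(g, [a], create_graph=True)[0]

# For contrast, the detach() version from the question loses the b -> a path.
b_cut = a.detach().requires_grad_()
g_cut = torch.autograd.grad(2 * S(a, b_cut), [b_cut], create_graph=True)[0]
h_cut = torch.autograd.grad(g_cut, [a])[0]

# Closed forms for f(a) = exp(-a*a), used only for checking.
with torch.no_grad():
    g_ref = -2 * a * (-a * a).exp()
    h_ref = (-a * a).exp() * (4 * a * a - 2)
    h_cut_ref = (-a * a).exp() * (2 * a * a - 2)

print("view_as:", torch.allclose(g, g_ref), torch.allclose(h, h_ref))
print("detach: ", torch.allclose(h_cut, h_ref), torch.allclose(h_cut, h_cut_ref))
```

Running it should report True for both view_as checks, and False then True for the detach() line: the detached graph reproduces exp(-a*a)*(2*a*a - 2), not the correct second derivative.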