# Computing the gradient of a symmetric function: graph surgery needed?

Hi everyone!
Let’s assume that I have at hand:

• a function `S(x,y)` which is symmetric with respect to `x` and `y`
• an input tensor `a`

and that I want to compute the gradient of `f(a) = S(a,a)` with respect to `a`, plus higher-order derivatives.
A standard way of doing this in PyTorch is to type:

```
import torch

S = lambda x, y: (-x * y).exp()  # or any other symmetric function of (x,y)

a = torch.tensor([3.], requires_grad=True)
f = S(a, a)  # = exp(-a*a)
g = torch.autograd.grad(f, [a], create_graph=True)[0]
# = -a*exp(-a*a) + -a*exp(-a*a)  -> inefficient. We should do better!
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*a) * (4*a*a - 2)
print(a, f, g, h)
print((-a * a).exp() * (4 * a * a - 2))
```
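
As a sanity check, recent versions of PyTorch can compute the same quantity with `torch.autograd.functional.hessian`; a minimal sketch (not part of the original snippet):

```
import torch
from torch.autograd.functional import hessian

S = lambda x, y: (-x * y).exp()
a = torch.tensor([3.])

# second derivative of the scalar map t -> S(t,t), evaluated at a
print(hessian(lambda t: S(t, t), a))     # ~ exp(-9) * 34
print((-a * a).exp() * (4 * a * a - 2))  # same value, by hand
```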

However, in theory, I could do much better: since I know that `S(x,y)` is symmetric, I also know that

```
dS/dx(a,a) = dS/dy(a,a)
```

and I would thus like to compute `df/da = dS(a,a)/da` as

```
df/da(a) = 2*dS/dx(a,a)     instead of     df/da(a) = dS/dx(a,a) + dS/dy(a,a)
```

which would halve my compute time. Is this possible? Well, kind of. Using the standard `.detach()` method, I was able to compute the right value for `df/da`, but could not get a fully differentiable expression:

```
a = torch.tensor([3.], requires_grad=True)
b = a.detach().requires_grad_()  # (reconstructed line) detached copy of a: the graph is cut here
f = S(a, b)  # = exp(-a*b)
g = 2 * torch.autograd.grad(f, [b], create_graph=True)[0]
# = -2*a*exp(-a*b)            -> efficient and correct, since a.data = b.data...
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*b) * (2*a*b - 2)   -> This is not what we want!
#                                The 4 was replaced by a 2!
print(a, f, g, h)
print((-a * b).exp() * (2 * a * b - 2))
```
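
Writing out what happens at second order, I believe the culprit is a lost cross-term: since `b` is detached, the second call to `grad` only differentiates through the explicit occurrences of `a` in `g`, and never flows back through `b`:

```
d/da [ 2*dS/dy(a,b) ], b detached  =  2*d2S/dxdy(a,a)                     <- what autograd returns
d/da [ 2*dS/dy(a,a) ]              =  2*d2S/dxdy(a,a) + 2*d2S/dy2(a,a)    <- what we want
```

For `S(x,y) = exp(-x*y)`, the missing term `2*d2S/dy2(a,a) = 2*a*a*exp(-a*a)` is exactly the gap between the `(2*a*b - 2)` above and the correct `(4*a*a - 2)`.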

Did I miss something? Is there a workaround?
Using undocumented low-level methods (?), I was thinking of replacing the references to `b` in `g.grad_fn` with references to `a`… Is this possible?
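
For what it's worth, here is the kind of thing I have in mind; this just walks the backward graph through the undocumented `next_functions` attribute (the `walk` helper is made up for illustration):

```
# illustrative only: print the backward graph of g by following the
# undocumented next_functions attribute of each autograd node
def walk(fn, depth=0):
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for child, _ in fn.next_functions:
        walk(child, depth + 1)

walk(g.grad_fn)  # the AccumulateGrad leaves are where a and b enter the graph
```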

This would be really useful to many researchers who are currently using PyTorch for shape analysis. Symmetric functions pop up all the time as high-dimensional kinetic energies, etc., and we would like to be able to backprop through their expressions without having to compute the same derivative twice.

You stopped the gradient at `b`, which should have backpropped to `a`. If you don't use `.detach()`, it works:

```
a = torch.tensor([3.], requires_grad=True)
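# (the snippet is truncated in the thread; below is a plausible completion
#  in the spirit of the advice: keep b attached to a through a
#  differentiable identity, so that grads can flow from b back to a)
b = a.view_as(a)   # NOT detached: the graph still links b to a
f = S(a, b)        # = exp(-a*b)
g = 2 * torch.autograd.grad(f, [b], create_graph=True)[0]
# one backward pass through the second slot only, then doubled
h = torch.autograd.grad(g, [a], create_graph=True)[0]
# = exp(-a*a) * (4*a*a - 2): the path from b back to a is kept,
# so the cross-term is recovered and the 4 is a 4 again
print(a, f, g, h)
print((-a * a).exp() * (4 * a * a - 2))
```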