How to avoid recalculating a function when we need to backpropagate through it twice?

I want to do the following calculation:

l1 = f(x.detach(), y)
l1.backward()
l2 = -1*f(x, y.detach())
l2.backward()

where f is some function, and x and y are tensors that require gradient. Notice that x and y may both be the results of previous calculations which utilize shared parameters (for example, maybe x = g(z) and y = g(w), where g is an nn.Module).

The issue is that l1 and l2 are both numerically identical, up to the minus sign, and it seems wasteful to repeat the calculation f(x,y) twice. It would be nicer to be able to calculate it once, and apply backward twice on the result. Is there any way of doing this?

One possibility is to manually call autograd.grad and update the w.grad field of each nn.Parameter w. But I’m wondering whether there is a more direct and clean way to do this, using the backward function.
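For concreteness, here is a minimal sketch of that manual route (the linear module, inputs, and loss below are stand-ins made up for illustration): evaluate f once, use torch.autograd.grad to get its gradients with respect to x and y, flip the sign on the x side, and let torch.autograd.backward accumulate into the shared parameters:

```python
import torch

lin = torch.nn.Linear(1, 1, bias=False)  # stands in for the shared module g
with torch.no_grad():
    lin.weight.fill_(1.0)
f = lambda x, y: (x - y).abs()  # stand-in for f

lin.zero_grad()
x = lin(torch.tensor([1.0]))
y = lin(torch.tensor([2.0]))

s = f(x, y)  # f computed only once
# gradients of f w.r.t. its two inputs; the traversal stops at x and y,
# so the graph upstream of them (through lin) stays intact
gx, gy = torch.autograd.grad(s, (x, y), grad_outputs=torch.ones_like(s))
# l1 = f(x.detach(), y) contributes +gy through y;
# l2 = -1*f(x, y.detach()) contributes -gx through x
torch.autograd.backward((x, y), (-gx, gy))
print(lin.weight.grad)  # tensor([[3.]])
```

With weight 1 and inputs 1 and 2 this matches the result of running the two separate backward calls by hand (2 from the y path plus 1 from the negated x path).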

Hi,

I’m afraid the only thing you can do is to give retain_graph=True the first time you call backward.

Can you please elaborate on how this would help?

This flag prevents the graph internal states from being freed. So you will be able to backprop through the graph as many times as you want (as mentioned in the error message you shared).

Thanks. I would appreciate it if you could give a code snippet demonstrating your suggestion. I am aware of retain_graph, but I don’t see how setting it to True will help accomplish what I asked for. It seems to me that there is still a missing ingredient that would let me do backward twice on the retained graph, each time keeping a different part of the graph detached.

l1 = f(x.detach(), y)
l1.backward(retain_graph=True)
l2 = -1*f(x, y.detach())
l2.backward()

This should not throw the error you mentioned :)

Thank you, but I did not mention any error. I wanted to avoid calculating the function f twice. I think that in your solution, it is calculated twice, right? Unless PyTorch automatically caches the first calculation of f and then in the second invocation the cached graph is reused somehow, including the detached subgraph composed of the ancestors of x?

Oh sorry, I think I misread the title of the topic and thought it was an error :(

Does the function f have any parameters in it?

If there is nothing else in there, and the only way to get to the parameters is via x and y, I would do:

x, y = g(input, params)
# f must have NO parameters
# Equivalent (by linearity of the gradient) to 
# l1 = f(x.detach(), y)
# l1.backward()
# l2 = -1*f(x, y.detach())
# l2.backward()
x.register_hook(lambda t: -t)
loss = f(x, y)
loss.backward()

Note that you should double-check that you still get the same gradients, as I might be missing something :)

This solved my problem nicely. Thanks.
Here is code demonstrating that the solution works:

import torch
lin = torch.nn.Linear(1,1,bias=False)
with torch.no_grad():
    lin.weight.fill_(1.0)
a = torch.tensor([1.0])
b = torch.tensor([2.0])
loss_func = lambda x,y: (x-y).abs()

# option 1: this is the inefficient option, presented in the original question
lin.zero_grad()
x = lin(a)
y = lin(b)
loss1 = loss_func(x.detach(),y)
loss1.backward(retain_graph=True)
loss2 = -1*loss_func(x,y.detach()) # second invocation of `loss_func` - not efficient!
loss2.backward()
print(lin.weight.grad)

# option 2: this is the efficient method, suggested by @albanD. 
lin.zero_grad()
x = lin(a)
y = lin(b)
x.register_hook(lambda t: -t)
loss = loss_func(x,y) # only one invocation of `loss_func` - more efficient!
loss.backward()
print(lin.weight.grad) # the output of this is identical to the previous print, which confirms the method

# option 3 - this should not be equivalent to the previous options, used just for comparison
lin.zero_grad()
x = lin(a)
y = lin(b)
loss = loss_func(x,y)
loss.backward()
print(lin.weight.grad)

However, it would be nice to know if there is a more general alternative that also works for non-linear cases. In other words, it would be interesting to know how to perform the following calculation with a single invocation of f(x, y) instead of two:

l1 = u(f(x.detach(), y))
l1.backward()
l2 = v(f(x, y.detach()))
l2.backward()

Here, u and v are two functions.

Hi,

non-linear cases

The gradient is always linear :)
This won’t work if u/v/f have parameters, though, because the flipping of the gradient for the second loss only happens before f. So parameters in these functions would just see the sum of the two losses.