I have the following sequence of operations:
```python
x = e(s)
y = f(x)
z1 = g1(y)
z2 = g2(y)
```
The gradient from `z1` would backprop through `g1`, `f`, `e`, and so on. But I want the gradient from `z2` to stop at `x`. I can do this by using `.detach()` on `x`:
```python
z1 = g1(f(x))
z2 = g2(f(x.detach()))
```
But this means that I now make two calls to `f`. Is there a better way of doing this, like detaching and copying the relevant part of the computation graph?
I believe one approach would be to use the new `inputs` argument of the `backward` operation, which allows you to specify the tensors that should accumulate the gradients. In your case you could pass the parameters of `f` and `g2` to the `backward` call on `z2`.
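Here is a minimal sketch of that first approach, assuming `e`, `f`, `g1`, and `g2` are small `nn.Linear` modules (hypothetical stand-ins, any modules would behave the same way) and a scalar loss is taken from each output:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for e, f, g1, g2.
e = nn.Linear(4, 4)
f = nn.Linear(4, 4)
g1 = nn.Linear(4, 4)
g2 = nn.Linear(4, 4)

s = torch.randn(2, 4)
x = e(s)
y = f(x)          # f is evaluated only once
z1 = g1(y)
z2 = g2(y)

# z1: accumulate gradients into all parameters as usual.
z1.sum().backward(retain_graph=True)  # keep the shared graph alive for z2

# z2: only the listed tensors accumulate gradients, so e's parameters
# (and anything upstream of x) never receive a gradient from z2.
z2.sum().backward(inputs=list(f.parameters()) + list(g2.parameters()))
```

Note that the first `backward` needs `retain_graph=True`, since both outputs share the graph created by `f`.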
Another possibility could be to detach the tensor as suggested by you, use `retain_grad()` so the gradient at the detached `x` is kept during the backward pass of `z1`, and then pass this gradient to `x.backward(grad)` to continue the backpropagation through `e`.
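And a sketch of that second approach, under the same assumptions; here I'm making the detached copy a leaf via `requires_grad_()`, so its `.grad` is populated without an explicit `retain_grad()` call:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for e, f, g1, g2.
e = nn.Linear(4, 4)
f = nn.Linear(4, 4)
g1 = nn.Linear(4, 4)
g2 = nn.Linear(4, 4)

s = torch.randn(2, 4)
x = e(s)
x_det = x.detach().requires_grad_()   # leaf tensor, so .grad is filled directly
y = f(x_det)                          # f is evaluated only once
z1 = g1(y)
z2 = g2(y)

# z1: its backward stops at x_det, so hand the gradient back to x manually
# to continue the backpropagation through e.
z1.sum().backward(retain_graph=True)  # keep f's graph alive for the z2 pass
x.backward(x_det.grad)

# z2: its backward also stops at x_det, so e never receives a gradient from it.
x_det.grad = None                     # clear z1's gradient at the cut point
z2.sum().backward()
```

With this setup `f` runs only once: `z2`'s gradient reaches `f` and `g2` but stops at `x_det`, while `z1`'s gradient is carried on through `e` by the explicit `x.backward(x_det.grad)` call.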