Can I modify one branch of the chain in the backward pass?

Hello, I want to do a kind of custom backpropagation with separate backward chains, but I don’t know how to do it with autograd.

Say I have the following set of operations:

y = f1(x)
u = f2(y)
v = f3(y)
z = f4(u, v)

Then the backward chain should be:

dz/dx = dz/dy * dy/dx
      = (dz/du * du/dy + dz/dv * dv/dy) * dy/dx        (1)

The above equation (1) can be further split into two chains:

dz/dx = (dz/du * du/dy * dy/dx) + (dz/dv * dv/dy * dy/dx)        (2)

However, I want to modify one branch of the backpropagation chain in equation (2), for example:

dz/dx = (dz/du * du/dy * dy/dx) * mean(dz/du * du/dy * dy/dx) + (dz/dv * dv/dy * dy/dx)

My understanding is that autograd computes dz/dx according to equation (1), not equation (2), so the two backpropagation chains may never exist in the autograd computation. How can I implement this gradient modification with autograd?

If you want to compute these two contributions independently, then you will need to either:

  • create a graph that contains only one branch by using .detach() on the other
  • use hooks on the second branch to zero out its contribution (e.g. `u.register_hook(lambda grad: torch.zeros_like(grad))`).

That way you can get the two terms by doing two backward passes, and then you can accumulate the gradients from the two backward passes to get the final value that you want (see the sketch below).
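Here is a minimal sketch of the .detach() option, assuming f1–f4 are simple differentiable stand-ins for the ops in the question (your real definitions will differ):

```python
import torch

# Hypothetical stand-ins for the ops in the question.
f1 = lambda x: x ** 2
f2 = lambda y: 3 * y
f3 = lambda y: y.sin()
f4 = lambda u, v: (u * v).sum()

x = torch.randn(5, requires_grad=True)

# Branch through u only: block the v branch by feeding it a detached y.
y = f1(x)
z_u = f4(f2(y), f3(y.detach()))
g_u, = torch.autograd.grad(z_u, x)

# Branch through v only: block the u branch instead (note the extra forward).
y = f1(x)
z_v = f4(f2(y.detach()), f3(y))
g_v, = torch.autograd.grad(z_v, x)

# Combine the two contributions with the modification from the question.
final_grad = g_u * g_u.mean() + g_v
```

Each branch needs its own graph here, which is why the forward work is duplicated; the hook variant avoids that.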

Thanks for the reply. But is there a more elegant solution? With the above approach, I have to do two forward and two backward computations for the gradient-modified sub-network, which is very inefficient. Is there a way to control how the autograd engine does backpropagation?

Using the hook version, you can do it with a single forward pass and two backward passes.
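For example, here is a minimal sketch with a single forward pass, again assuming hypothetical f1–f4; the hooks are toggled so that each backward pass zeroes out one branch:

```python
import torch

# Hypothetical stand-ins for the ops in the question.
f1 = lambda x: x ** 2
f2 = lambda y: 3 * y
f3 = lambda y: y.sin()
f4 = lambda u, v: (u * v).sum()

x = torch.randn(5, requires_grad=True)

# Single forward pass.
y = f1(x)
u = f2(y)
v = f3(y)
z = f4(u, v)

# Toggle which branch is silenced during each backward pass.
block = {"u": False, "v": False}
u.register_hook(lambda g: torch.zeros_like(g) if block["u"] else g)
v.register_hook(lambda g: torch.zeros_like(g) if block["v"] else g)

# First backward: only the u branch contributes to the gradient of x.
block["v"] = True
g_u, = torch.autograd.grad(z, x, retain_graph=True)

# Second backward: only the v branch contributes.
block["u"], block["v"] = True, False
g_v, = torch.autograd.grad(z, x)

# Accumulate with the desired modification applied to the first branch.
x.grad = g_u * g_u.mean() + g_v
```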

Is there a way to control how the autograd engine does backpropagation?

There is, but what you're asking for here is a bit weird. It breaks one of the main assumptions, which is that if you had a Tensor foo during the forward pass, then the quantity being backpropagated for it is another Tensor of the same size.
In this case, you actually want two values to flow back at once at the top of your graph, so you implicitly do two backward passes; it is just not explicit in your math.