Local modifications of backpropagation in PyTorch

Hi,

In PyTorch, I am looking for a flexible way to modify how backpropagation is computed in a neural network with multiple layers.

More precisely, I want to create blocks of layers indexed by k in {1,...,K} such that:

  1. the output is computed by the formula x_{k+1} = f_k(g_k(x_k, theta_k)), where g_k is the composition of multiple layers with learnable parameters theta_k, and f_k is some function that does not have learnable parameters.
  2. during the backpropagation, the gradient with respect to theta_k is calculated using the above formula.
  3. still during backpropagation, the gradient with respect to x_k is computed as if the operation were x_{k+1} = h_k(g_k(x_k, theta_k)), where h_k is a function different from f_k.
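For reference, here is the requested backward pass for block k written out with the chain rule. The symbols z_k (the output of g_k) and \tilde\delta_k (the modified gradient propagated to x_k) are introduced here for illustration and do not appear in the original question:

```latex
\begin{aligned}
z_k &= g_k(x_k, \theta_k), \qquad x_{k+1} = f_k(z_k), \qquad \tilde\delta_{K+1} = \frac{\partial L}{\partial x_{K+1}},\\
\widehat{\nabla}_{\theta_k} L &= \left(\frac{\partial g_k}{\partial \theta_k}\right)^{\top} \left(\frac{\partial f_k}{\partial z_k}\right)^{\top} \tilde\delta_{k+1},\\
\tilde\delta_k &= \left(\frac{\partial g_k}{\partial x_k}\right)^{\top} \left(\frac{\partial h_k}{\partial z_k}\right)^{\top} \tilde\delta_{k+1}.
\end{aligned}
```

The parameter gradient uses the Jacobian of f_k and the x-gradient uses that of h_k. Both are driven by the same incoming gradient \tilde\delta_{k+1}, which for k < K already carries the modifications made in the later blocks.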

I tried using hooks, but I could not figure out a simple way to do this. I asked ChatGPT-4 several times, but none of the answers solved my problem, and each one takes a long time to check. Any help would be much appreciated.

Hey!

There are a couple of possible approaches here:

  • You can override the whole block to do whatever you want with a custom Function (Extending PyTorch, PyTorch 2.2 documentation). In this case, you might want to do a “reentrant backward”, where your custom backward calls backward on other pieces. You will need to handle enable_grad properly to get the behavior you want (a related example is torch/utils/checkpoint.py in pytorch/pytorch on GitHub, at commit d3839b624b5f6451a13bd9b5ecbbce4c2a9b1db6).
  • You can override only one smaller function (f_k) to give it a custom backward behavior (this should be simpler than the above). The problem is that, since you want the gradient for each input to be computed with a different backward for f_k, you will have to make two calls to backward: one with respect to theta_k using the original f_k gradient, and one with respect to x_k using the modified (h_k) gradient.
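A minimal single-block sketch of the second approach, assuming g_k is an nn.Module and f_k, h_k are parameter-free elementwise callables. The names FForwardHBackward and single_block_grads are hypothetical. The custom Function returns f(z) in forward but backpropagates through h; g_k runs once, and the two torch.autograd.grad calls give the theta_k gradient through f_k and the x_k gradient through h_k:

```python
import torch
import torch.nn as nn


class FForwardHBackward(torch.autograd.Function):
    """Hypothetical helper: outputs f(z) in forward, but backpropagates through h instead of f."""

    @staticmethod
    def forward(ctx, z, f, h):
        ctx.save_for_backward(z)
        ctx.h = h
        return f(z)  # grad mode is disabled inside forward; apply() records this node

    @staticmethod
    def backward(ctx, grad_out):
        (z,) = ctx.saved_tensors
        # Grad mode is off during backward by default, so enable it to differentiate h.
        with torch.enable_grad():
            z = z.detach().requires_grad_(True)
            (grad_z,) = torch.autograd.grad(ctx.h(z), z, grad_out)
        return grad_z, None, None  # no gradients for the callables f and h


def single_block_grads(x, g, f, h, loss_fn):
    """Hypothetical helper for ONE block: theta-gradient through f, x-gradient through h."""
    z = g(x)                                            # g_k(x_k, theta_k), computed once
    loss_f = loss_fn(f(z))                              # ordinary graph through f
    loss_h = loss_fn(FForwardHBackward.apply(z, f, h))  # same value, backward through h
    params = list(g.parameters())
    # Backward call 1: gradient w.r.t. theta using the original f gradient.
    grads_theta = torch.autograd.grad(loss_f, params, retain_graph=True)
    # Backward call 2: gradient w.r.t. x using the modified (h) gradient.
    (grad_x,) = torch.autograd.grad(loss_h, x)
    return grads_theta, grad_x
```

This only covers one block whose output feeds the loss directly. With several stacked blocks, the theta_k gradient needs the modified gradient arriving at x_{k+1}, so splitting the whole network into one "all f" pass and one "all h" pass would not reproduce the requested gradients.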

Thank you for the reply. I had already heard about the first approach you propose and wanted to avoid it, but I think I will go with it. I do not think the second one would work in my situation.
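Since the first approach is the one chosen, here is a minimal sketch of it, assuming g_k is an nn.Module (with at least one trainable parameter) and f_k, h_k are parameter-free callables. HybridBlockFunction and hybrid_block are hypothetical names. The forward computes f_k(g_k(x_k)) without recording a graph; the backward re-runs g_k under torch.enable_grad() (the reentrant pattern described in the reply) and calls torch.autograd.grad twice, once through f_k for the parameters and once through h_k for x_k:

```python
import torch
import torch.nn as nn


class HybridBlockFunction(torch.autograd.Function):
    """Hypothetical sketch: x_{k+1} = f(g(x, theta)); theta-gradient via f, x-gradient via h."""

    @staticmethod
    def forward(ctx, x, g, f, h, *params):
        # params are passed explicitly only so autograd links the output to them.
        ctx.save_for_backward(x)
        ctx.g, ctx.f, ctx.h = g, f, h
        return f(g(x))  # grad mode is already disabled inside forward

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Same filter and order as in hybrid_block, so gradients line up with *params.
        params = [p for p in ctx.g.parameters() if p.requires_grad]
        grad_x = None
        # Reentrant backward: rebuild the graph of g with grad mode enabled.
        with torch.enable_grad():
            x = x.detach().requires_grad_(ctx.needs_input_grad[0])
            z = ctx.g(x)
            # Parameter gradient: ordinary chain rule through f.
            grads_theta = torch.autograd.grad(ctx.f(z), params, grad_out, retain_graph=True)
            # Input gradient: chain rule through h instead of f.
            if ctx.needs_input_grad[0]:
                (grad_x,) = torch.autograd.grad(ctx.h(z), x, grad_out)
        # One entry per forward input: x, g, f, h, *params.
        return (grad_x, None, None, None, *grads_theta)


def hybrid_block(x, g, f, h):
    """Hypothetical helper that applies one block."""
    params = [p for p in g.parameters() if p.requires_grad]
    return HybridBlockFunction.apply(x, g, f, h, *params)
```

Blocks built this way can be chained, e.g. y = hybrid_block(hybrid_block(x, g1, f1, h1), g2, f2, h2), and the modified x-gradient of the second block is what reaches the first block's parameters. Like checkpointing, this recomputes g_k during backward, trading compute for memory; unlike torch/utils/checkpoint.py, it does not save and restore the RNG state, so dropout inside g_k would draw a different mask on recomputation.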