In PyTorch, I am looking for a flexible way to modify the backpropagation computation in a neural network with multiple layers.
More precisely, I want to create blocks of layers indexed by k in {1,...,K} such that:
the output is computed by the formula x_{k+1} = f_k(g_k(x_k, theta_k)), where g_k is a composition of several layers with learnable parameters theta_k, and f_k is a function with no learnable parameters;
during backpropagation, the gradient with respect to theta_k is computed from this formula;
still during backpropagation, the gradient with respect to x_k is computed as if the operation had been x_{k+1} = h_k(g_k(x_k, theta_k)), where h_k is a function different from f_k.
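For concreteness, here is a plain-PyTorch version of the forward pass described above, before any backward modification; the specific choices of f, g, and h below are hypothetical placeholders (in the real setting g_k would be a stack of layers):

```python
import torch

# Hypothetical single block (one value of k); all concrete shapes and
# functions below are placeholders, not from the actual network.
torch.manual_seed(0)

theta = torch.randn(4, 4, requires_grad=True)   # learnable parameters of g_k

def g(x, theta):
    return x @ theta        # g_k: the learnable sub-block

def f(y):
    return torch.relu(y)    # f_k: parameter-free forward map

def h(y):
    return y                # h_k: the map whose Jacobian should drive dL/dx_k

x = torch.randn(2, 4, requires_grad=True)
out = f(g(x, theta))        # forward: x_{k+1} = f_k(g_k(x_k, theta_k))
# Goal: dL/dtheta through f as written, but dL/dx as if out were h(g(x, theta)).
```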
I tried using hooks, but I could not figure out a simple way to do this. I asked ChatGPT4 multiple times, but every time the answer does not solve my problem, and each answer takes a long time to check. Any help would be much appreciated.
You can override just one small function (f_k) to have custom backward behavior, which should be simpler than hooking the whole network. The problem is that, since you want the gradients for the two inputs computed with different backwards through f_k, you will have to make two calls to backward: one w.r.t. theta_k with the original f_k gradient, and one w.r.t. x_k with the modified (h_k) gradient.
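A minimal sketch of that two-backward scheme, using a custom torch.autograd.Function to override f_k's backward; the toy choices g(x, theta) = theta * x, f(y) = y**2, and h(y) = y are my assumptions for illustration, not from the question:

```python
import torch

# Toy instance (assumed): g(x, theta) = theta * x,
# f(y) = y**2 (true forward map), h(y) = y (so h'(y) = 1).

class FWithHBackward(torch.autograd.Function):
    """Forward computes f(y) = y**2; backward pretends the op was h(y) = y."""

    @staticmethod
    def forward(ctx, y):
        return y ** 2                  # f(y)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 1.0       # h'(y) = 1, instead of f'(y) = 2*y

x = torch.tensor(3.0, requires_grad=True)
theta = torch.tensor(2.0, requires_grad=True)
y = theta * x                          # g(x, theta), shared by both passes

# Backward call 1: original f, gradient taken only w.r.t. theta.
loss_true = (y ** 2).sum()
(grad_theta,) = torch.autograd.grad(loss_true, theta, retain_graph=True)

# Backward call 2: modified backward (h), gradient taken only w.r.t. x.
loss_mod = FWithHBackward.apply(y).sum()
(grad_x,) = torch.autograd.grad(loss_mod, x)

print(grad_theta)  # d(y**2)/dtheta = 2*y*x = 36
print(grad_x)      # h'(y) * dy/dx = 1 * theta = 2
```

Note the retain_graph=True on the first call: both losses share the graph of y = g(x, theta), so it must survive the first backward pass.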
Thank you for the reply. I had already heard about the first approach you are proposing and wanted to avoid it, but I think I will go with it. I do not think the second one would work in my situation.