I want to implement a PyTorch model in which I can set the gradient manually during backpropagation. Here is an example:
import torch
a = torch.randn([2, 2], requires_grad=True)
b = 2 * a
c = func(b)
d = 2 * c
Here func is a differentiable function with an explicit form. However, backpropagating through it directly is computationally very expensive. I have an algorithm that computes dc/db directly, without backpropagation. Is there a ready-made way to do the following: first backpropagate from d to c, then manually set the gradient from c to b, then propagate from b to a, and finally obtain the gradient of d with respect to a?
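One common way to plug a hand-computed derivative into autograd is a custom `torch.autograd.Function`: autograd handles d -> c and b -> a as usual and calls your own `backward` for c -> b, so it never traces the internals of func. Below is a minimal sketch under stated assumptions: `func_forward` (elementwise `sin`) and `func_grad` (elementwise `cos`) are hypothetical stand-ins for the real func and for the direct dc/db algorithm, and func is assumed elementwise so that dc/db is diagonal.

```python
import torch


def func_forward(b):
    # Hypothetical stand-in for the expensive "func": elementwise sin.
    return torch.sin(b)


def func_grad(b):
    # Hypothetical stand-in for the algorithm that computes dc/db directly.
    # Assumes func is elementwise, so the Jacobian is diagonal and an
    # elementwise tensor of derivatives is enough.
    return torch.cos(b)


class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, b):
        # Save the input so backward can evaluate dc/db at this point.
        ctx.save_for_backward(b)
        # Grad mode is off inside forward, so no graph is built for func.
        return func_forward(b)

    @staticmethod
    def backward(ctx, grad_c):
        # grad_c is dd/dc, already propagated from d back to c by autograd.
        (b,) = ctx.saved_tensors
        # Manual chain rule: dd/db = dd/dc * dc/db (elementwise here).
        return grad_c * func_grad(b)


torch.manual_seed(0)
a = torch.randn([2, 2], requires_grad=True)
b = 2 * a
c = Func.apply(b)   # autograd records Func instead of func's internals
d = 2 * c
d.sum().backward()  # d -> c (autograd), c -> b (Func.backward), b -> a (autograd)

# Reference value: dd/da = 2 * cos(2a) * 2 for this stand-in func.
expected = 2 * torch.cos(2 * a.detach()) * 2
print(torch.allclose(a.grad, expected))  # True
```

If func is not elementwise, `backward` must instead return the vector-Jacobian product, i.e. `grad_c` contracted with the full Jacobian dc/db, with the same shape as `b`.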
I think it could be done by splitting up the whole backpropagation process, but that is still too complicated for me. I also tried using something like a backward hook (backward_hook), but I cannot pass any additional input into that function. I would be grateful for any help with this problem.
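Splitting the backward pass manually is also fewer lines than it may seem, and it needs no hooks: cut the graph at b and at c with `detach()`, use `torch.autograd.grad` for d -> c, apply the manual algorithm for c -> b, and seed `b.backward(...)` with the result for b -> a. This is a sketch under the same hypothetical assumptions: `func_forward` and `func_vjp` (elementwise `sin`/`cos`) stand in for the real func and the direct dc/db algorithm.

```python
import torch


def func_forward(b):
    # Hypothetical stand-in for the expensive "func".
    return torch.sin(b)


def func_vjp(b, grad_c):
    # Hypothetical stand-in for the direct algorithm: returns dd/db given dd/dc.
    # Assumes func is elementwise, so dc/db is diagonal.
    return grad_c * torch.cos(b)


torch.manual_seed(0)
a = torch.randn([2, 2], requires_grad=True)
b = 2 * a

# Cut the graph at b and at c so autograd never traces func.
b_cut = b.detach()
c = func_forward(b_cut).requires_grad_()
d = 2 * c

# Step 1: d -> c with autograd.
(grad_c,) = torch.autograd.grad(d.sum(), c)
# Step 2: c -> b with the manual algorithm.
grad_b = func_vjp(b_cut, grad_c)
# Step 3: b -> a with autograd, seeding the backward pass with grad_b.
b.backward(grad_b)

print(torch.allclose(a.grad, 4 * torch.cos(2 * a.detach())))  # True
```

This works, but the bookkeeping grows with every such cut in a larger model; wrapping the same three steps in a custom `torch.autograd.Function` keeps the forward code unchanged.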