Hi, I’m new to PyTorch and I’m trying to implement the RevNet paper, so that activations are not saved during the forward pass but are instead recomputed during backprop. The model is built from blocks like the one in the following image (a), where F and G are modules (bn -> relu -> conv -> …)
My plan is to implement the block above as an autograd function, with F and G as modules that I call from its forward. During the backward pass we can recover x_1 and x_2 from y_1 and y_2 (img (b)) and recompute the activations of the modules. However, I’m still not sure how to implement this in PyTorch:
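To make the question concrete, here is a minimal sketch of what I have in mind, assuming F and G are arbitrary nn.Modules with matching shapes (the name ReversibleBlockFn and the Linear stand-ins are just for illustration, not from the paper):

```python
import torch
import torch.nn as nn

class ReversibleBlockFn(torch.autograd.Function):
    """Coupling block: y1 = x1 + F(x2), y2 = x2 + G(y1).
    Only the outputs are stored; the inputs are reconstructed in backward."""

    @staticmethod
    def forward(ctx, x1, x2, F, G):
        ctx.F, ctx.G = F, G
        with torch.no_grad():              # no graph -> no stored activations
            y1 = x1 + F(x2)
            y2 = x2 + G(y1)
        ctx.save_for_backward(y1, y2)
        return y1, y2

    @staticmethod
    def backward(ctx, dy1, dy2):
        F, G = ctx.F, ctx.G
        y1, y2 = ctx.saved_tensors
        with torch.no_grad():              # invert the block (img (b))
            x2 = y2 - G(y1)
            x1 = y1 - F(x2)
        x1 = x1.detach().requires_grad_(True)
        x2 = x2.detach().requires_grad_(True)
        with torch.enable_grad():          # recompute, this time with a graph
            r1 = x1 + F(x2)
            r2 = x2 + G(r1)
        params = list(F.parameters()) + list(G.parameters())
        grads = torch.autograd.grad((r1, r2), (x1, x2) + tuple(params),
                                    grad_outputs=(dy1, dy2))
        dx1, dx2 = grads[0], grads[1]
        for p, g in zip(params, grads[2:]):   # accumulate parameter grads by hand
            p.grad = g if p.grad is None else p.grad + g
        return dx1, dx2, None, None

# Usage with Linear layers standing in for F and G:
F, G = nn.Linear(4, 4), nn.Linear(4, 4)
x1 = torch.randn(2, 4, requires_grad=True)
x2 = torch.randn(2, 4, requires_grad=True)
y1, y2 = ReversibleBlockFn.apply(x1, x2, F, G)
(y1.sum() + y2.sum()).backward()   # populates x1.grad, x2.grad and p.grad
```

The idea is that backward first reconstructs the inputs from the outputs, then reruns F and G with grad enabled so a single torch.autograd.grad call yields both the input and the parameter gradients; the parameter grads are accumulated into .grad manually.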
- How do I compute and apply the gradients of F? Say I know the gradient of the output, I know the input, and I have the module F: how can I compute dF/dx_2 and dF/dw_F?
- Will setting requires_grad=False for everything prevent the activations from being saved? And will it still allow manually computing and applying the gradients? (I guess this may depend on the answer to question 1.)
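For question 1, I think torch.autograd.grad might be what I’m after: rerun F on the (recomputed) input with grad tracking on, then ask for the gradients of its output with respect to both the input and the parameters. A small sketch, with a Linear layer standing in for F:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for F; in the real block it would be bn -> relu -> conv -> ...
F = nn.Linear(4, 4)

# Rerun F on the reconstructed input with grad tracking enabled.
x2 = torch.randn(2, 4).requires_grad_(True)
out = F(x2)

# grad_out plays the role of the known gradient of the loss w.r.t. F's output.
grad_out = torch.ones_like(out)

# One call gives both dL/dx_2 and dL/dw_F, without touching any .grad fields.
grads = torch.autograd.grad(out, [x2] + list(F.parameters()),
                            grad_outputs=grad_out)
dF_dx2, dF_dw = grads[0], grads[1:]
```

For a Linear layer this can be sanity-checked analytically: with y = x W^T + b, the input gradient is grad_out @ W.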
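For question 2, my understanding (which I’d like confirmed) is that running the forward under torch.no_grad() is what prevents a graph, and hence the activations, from being kept, and that I can still get gradients later by detaching, re-marking the input with requires_grad_, and rerunning the module under torch.enable_grad(). A tiny experiment illustrating that assumption:

```python
import torch
import torch.nn as nn

F = nn.Linear(4, 4)
x = torch.randn(2, 4)

# Under no_grad, no graph is recorded, so no intermediate activations are kept.
with torch.no_grad():
    y = F(x)
assert y.grad_fn is None       # nothing was recorded

# Re-enabling grad later still works: detach, mark the input, and rerun F.
x_req = x.detach().requires_grad_(True)
with torch.enable_grad():
    y_tracked = F(x_req)
assert y_tracked.grad_fn is not None   # a graph exists now, gradients can flow
```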