I have an auto-encoder network f and a loss L(f(f(x))), i.e. f is applied twice.
During optimization, I only want to collect gradients from the first application of f, not the second.
One solution is to create two copies of f, giving L(f2(f1(x))), then optimize only f1 and copy its state_dict to f2 after each optimizer step.
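A minimal sketch of that workaround, assuming a toy auto-encoder and MSE loss (the names f1, f2, L, and the architecture are illustrative, not from the original):

```python
import copy
import torch
import torch.nn as nn

# Toy auto-encoder standing in for f.
f1 = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))
f2 = copy.deepcopy(f1)
for p in f2.parameters():
    p.requires_grad_(False)  # freeze the copy; only f1 receives gradients

opt = torch.optim.SGD(f1.parameters(), lr=0.1)
L = nn.MSELoss()

x = torch.randn(8, 4)
opt.zero_grad()
loss = L(f2(f1(x)), x)  # gradients flow into f1 only
loss.backward()
opt.step()

f2.load_state_dict(f1.state_dict())  # keep the copy in sync after each step
```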
Is there a more efficient (cleaner) solution?
(If I wanted to collect the gradients from the second application (f2) instead, I could compute f1(x) and then detach it, but I don't know how to do the equivalent when I want the gradients from the first application (f1).)
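For reference, the opposite case mentioned above could be sketched like this (toy model and loss are assumptions): detaching the inner output blocks gradients from the first application, so only the second one contributes.

```python
import torch
import torch.nn as nn

# Toy auto-encoder standing in for f.
f = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))
L = nn.MSELoss()

x = torch.randn(8, 4)
inner = f(x).detach()   # cut the graph: first application contributes no gradients
loss = L(f(inner), x)   # gradients flow only through the second application
loss.backward()
```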