I have an auto-encoder network f and a loss:

L(f(f(x)))

I only want to collect the gradients from the first application of f and not the second during optimization.

One solution is to create two copies of f: L(f2(f1(x))), optimize only f1, and copy the state_dict to f2 after each optimizer step.
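For concreteness, here is a minimal sketch of that workaround, assuming a toy `nn.Sequential` stand-in for the auto-encoder and a plain MSE reconstruction loss (both hypothetical, just for illustration):

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny auto-encoder standing in for f
f1 = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))
f2 = copy.deepcopy(f1)          # copy used for the second application
for p in f2.parameters():
    p.requires_grad_(False)     # no parameter grads collected through f2

opt = torch.optim.SGD(f1.parameters(), lr=0.1)
x = torch.randn(8, 4)

for _ in range(3):
    opt.zero_grad()
    # Gradients still flow *through* f2 back to f1's parameters,
    # but f2's own parameters accumulate no gradient.
    loss = ((f2(f1(x)) - x) ** 2).mean()
    loss.backward()
    opt.step()
    f2.load_state_dict(f1.state_dict())  # keep f2 in sync after each step
```

This works, but it pays for a deep copy up front and a state_dict copy every step, which is what I would like to avoid.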

Is there a more efficient (cleaner) solution?

(If I wanted to collect the gradients from the second application (f2), I could compute f1's output and then detach it... but I don't know how to do the equivalent when I want to collect the gradients from the first application (f1).)
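That detach approach for the opposite case might look like this (again with a hypothetical toy module standing in for f):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny auto-encoder standing in for f
f = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))
x = torch.randn(8, 4)

y = f(x).detach()                  # cut the graph: first application contributes no grads
loss = ((f(y) - x) ** 2).mean()    # only the second application is in the graph
loss.backward()                    # grads come solely from the second application
```

There is no analogous single call that detaches the *second* application while keeping the first in the graph, which is why I'm asking.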

Thanks