Only collect gradients from part of the graph

I have an auto-encoder network f and a loss: L(f(f(x)))

I only want to collect the gradients from the first application of f and not the second during optimization.

One solution is to create two copies of f, f1 and f2, and compute L(f2(f1(x))). Then I only optimize f1 and copy its state_dict to f2 after each optimizer step.
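A minimal sketch of this two-copy workaround, assuming a toy nn.Sequential auto-encoder and a mean-squared reconstruction loss. The names f1, f2, L, x and the layer sizes are hypothetical stand-ins, not from the original post. Freezing f2's parameters is an extra choice made here so that no gradients are stored for the copy:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy auto-encoder and input (stand-ins for the real f and data)
f1 = nn.Sequential(nn.Linear(4, 2), nn.Tanh(), nn.Linear(2, 4))
x = torch.randn(8, 4)

# Second copy: never handed to the optimizer, and frozen so it stores no .grad
f2 = copy.deepcopy(f1)
for p in f2.parameters():
    p.requires_grad_(False)


def L(z):
    # Hypothetical reconstruction loss against the input x
    return ((z - x) ** 2).mean()


opt = torch.optim.SGD(f1.parameters(), lr=0.1)

for step in range(3):
    opt.zero_grad()
    loss = L(f2(f1(x)))  # gradients only reach f1's parameters
    loss.backward()
    opt.step()
    f2.load_state_dict(f1.state_dict())  # keep f2 in sync with the updated f1
```

The cost is an extra copy of the weights plus a state_dict copy after every step, which is what makes the question ask for something cleaner.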

Is there a more efficient (cleaner) solution?

(If I wanted to collect the gradients from the second application (f2), I could compute f1(x) and then detach it, but I don’t know how to do the equivalent when I want the gradients from the first application (f1).)
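For that second-application case, the detach idea can be sketched as follows. The toy f, the input x and the loss L are hypothetical stand-ins for the real network and loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the auto-encoder f, the input x and the loss L
f = nn.Sequential(nn.Linear(4, 2), nn.Tanh(), nn.Linear(2, 4))
x = torch.randn(8, 4)


def L(z):
    return ((z - x) ** 2).mean()


y = f(x).detach()  # first application: cut out of the graph
z = f(y)           # second application: recorded as usual
L(z).backward()    # f's .grad now holds only the second application's contribution
```

This direction is easy because detach() cuts everything upstream of y; the first-application case is harder because the second application sits downstream of the part you want to keep.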


Do it this way:

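```python
from torch.autograd import grad

# Option 1: compute dL/dy on a detached copy, then backprop it through the first f only
y = f(x)
y_d = y.detach().requires_grad_()  # detach() alone gives a tensor that does not require grad
l = L(f(y_d))
(y_grad,) = grad(l, y_d)  # gradient w.r.t. y_d only; f's .grad is left untouched
y.backward(y_grad)

# Option 2: freeze f's parameters while applying f the second time
y = f(x)
for p in f.parameters(): p.requires_grad = False
z = f(y)
for p in f.parameters(): p.requires_grad = True
L(z).backward()
```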

I remember this works.
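A quick way to check this: the sketch below computes the gradients both ways shown above and compares them. The toy f, x and L are hypothetical stand-ins, and the helper names grads_via_detach and grads_via_freezing are illustrative only:

```python
import torch
import torch.nn as nn
from torch.autograd import grad

torch.manual_seed(0)

# Hypothetical stand-ins for the auto-encoder f, the input x and the loss L
f = nn.Sequential(nn.Linear(4, 2), nn.Tanh(), nn.Linear(2, 4))
x = torch.randn(8, 4)


def L(z):
    return ((z - x) ** 2).mean()


def grads_via_detach():
    f.zero_grad()
    y = f(x)
    y_d = y.detach().requires_grad_()
    (y_grad,) = grad(L(f(y_d)), y_d)  # dL/dy only; does not touch f's .grad
    y.backward(y_grad)                # backprop through the first application only
    return [p.grad.clone() for p in f.parameters()]


def grads_via_freezing():
    f.zero_grad()
    y = f(x)
    for p in f.parameters():
        p.requires_grad_(False)
    z = f(y)                          # recorded without edges to f's parameters
    for p in f.parameters():
        p.requires_grad_(True)
    L(z).backward()
    return [p.grad.clone() for p in f.parameters()]


g1 = grads_via_detach()
g2 = grads_via_freezing()
print(all(torch.allclose(a, b) for a, b in zip(g1, g2)))
```

In real code, wrapping the freeze/unfreeze pair in try/finally would keep an exception during the second forward pass from leaving f permanently frozen.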


I never really understand when PyTorch uses the requires_grad attribute. It is certainly used at forward time; however, it sometimes also seems to have implications for the backward pass. Do you know more details? Thanks in advance!

It seems to me that the autograd engine records each operation when it runs, based on the requires_grad flags of its inputs at that moment. Once an operation has been recorded, it stays in the graph. (The downside is that you cannot dynamically change a graph you have already created.)
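A small sketch of that recording behaviour with plain tensors (w and a are hypothetical): requires_grad is consulted when each operation runs, so flipping it later does not change what was already recorded.

```python
import torch

w = torch.ones(3, requires_grad=True)  # hypothetical "parameter"
a = torch.ones(3)                      # hypothetical input

y1 = (w * a).sum()      # w requires grad now, so this op records an edge to w
w.requires_grad_(False)
y2 = (w * a).sum()      # w is treated as a constant: nothing is recorded
w.requires_grad_(True)  # flipping back does not retroactively change y2

print(y1.requires_grad, y2.requires_grad)  # True False
y1.backward()           # y1's graph was fixed when it was built
print(w.grad)           # tensor([1., 1., 1.])
```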


Thanks! This seems to work and is very elegant.

I guess the idea is that when y is added to the computation graph, f's parameters have requires_grad=True, so the graph records them.
Setting requires_grad=False on f's parameters afterwards does not affect y, because it is already part of the graph.
But when z is computed, no new gradient edges to f's parameters are recorded.
Is that the correct reasoning?
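One way to test this reasoning is to compare the freezing trick with the two-copy approach from the question, where a frozen deepcopy plays the role of f2. This is a sketch with a hypothetical toy f, x and L; if the reasoning holds, both give the same gradients for f's parameters.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the auto-encoder f, the input x and the loss L
f = nn.Sequential(nn.Linear(4, 2), nn.Tanh(), nn.Linear(2, 4))
x = torch.randn(8, 4)


def L(z):
    return ((z - x) ** 2).mean()


# Reference: a frozen copy acts as the second application (the f2 idea)
f2 = copy.deepcopy(f)
for p in f2.parameters():
    p.requires_grad_(False)
L(f2(f(x))).backward()
ref = [p.grad.clone() for p in f.parameters()]

# Freezing trick: same module, parameters frozen only while z is built
f.zero_grad()
y = f(x)                    # y's graph records edges to f's parameters
for p in f.parameters():
    p.requires_grad_(False)
z = f(y)                    # no new edges to f's parameters are recorded
for p in f.parameters():
    p.requires_grad_(True)  # does not retroactively change z's graph
L(z).backward()

print(all(torch.allclose(r, p.grad) for r, p in zip(ref, f.parameters())))
```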

Thanks again