Is it possible to no_grad() a downstream part of a network while still allowing the gradients to flow to the upstream part?

Say I have two DNN modules F and G.

out = F(G(x_1), x_2)

And a loss function like loss(out, label).

I want to disable gradients for F during one particular update (though not necessarily in others).

I tried this during that update:

g = G(x_1)
with torch.no_grad():
    out = F(g, x_2)

But (correct me if I’m wrong) it seems that gradients don’t propagate back to G in this case, and I need them to.

Your code won’t work as intended. Because out is calculated inside the no_grad block, it is not attached to the computation graph, so calling backward() on out (or on the output of any subsequent operation on it) will raise:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

For example:

import torch
import torch.nn as nn

G = nn.Linear(1, 1)
F = nn.Linear(1, 1)
x = torch.randn(1, 1)

g = G(x)
with torch.no_grad():
    out = F(g)

print(out.grad_fn)
> None

out.backward()
> RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

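To disable the gradient calculation for F’s parameters while still letting gradients flow through F back to G, you could set the requires_grad attribute of all parameters of F to False. Autograd still records F’s operations, because its input g requires grad: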

import torch
import torch.nn as nn

G = nn.Linear(1, 1)
F = nn.Linear(1, 1)
x = torch.randn(1, 1)

for param in F.parameters():
    param.requires_grad = False

g = G(x)
out = F(g)

out.backward()
print(out.grad_fn)
> <AddmmBackward object at >
print(G.weight.grad)
> tensor([[0.2877]])
print(F.weight.grad)
> None
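Since you only want F frozen during one particular update, you can set the flags for that step and restore them afterwards. Below is a minimal sketch, assuming a hypothetical two-input module TwoInputF (standing in for your F(g, x_2)) and a hypothetical helper frozen(). Neither is a PyTorch API. The forward pass must run while F is frozen, because autograd records which tensors need gradients when the graph is built.

```python
import contextlib

import torch
import torch.nn as nn


class TwoInputF(nn.Module):
    """Hypothetical stand-in for F(g, x_2): concatenates both inputs."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, g, x_2):
        return self.fc(torch.cat([g, x_2], dim=1))


@contextlib.contextmanager
def frozen(module):
    """Hypothetical helper: temporarily set requires_grad=False on all of
    module's parameters, restoring each parameter's previous flag on exit."""
    params = list(module.parameters())
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        yield module
    finally:
        for p, flag in zip(params, saved):
            p.requires_grad_(flag)


torch.manual_seed(0)
G = nn.Linear(1, 1)
F = TwoInputF()
opt = torch.optim.SGD(list(G.parameters()) + list(F.parameters()), lr=0.1)

x_1, x_2 = torch.randn(4, 1), torch.randn(4, 1)
label = torch.randn(4, 1)

# Snapshots to check which weights the update actually changes.
g_before = G.weight.detach().clone()
f_before = F.fc.weight.detach().clone()

# One update in which F is frozen but gradients still reach G.
opt.zero_grad()
with frozen(F):
    out = F(G(x_1), x_2)
    loss = nn.functional.mse_loss(out, label)
    loss.backward()
opt.step()  # SGD skips parameters whose .grad is None

print(F.fc.weight.grad)                              # None: F got no gradient
print(torch.equal(F.fc.weight, f_before))            # True: F unchanged
print(torch.equal(G.weight, g_before))               # False: G was updated
print(all(p.requires_grad for p in F.parameters()))  # True: flags restored
```

nn.Module also provides F.requires_grad_(False), which sets the flag on every parameter in one call. The helper above instead saves and restores each parameter’s previous flag, so parameters you had frozen on purpose stay frozen after the update.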