Say I have two DNN modules F and G, combined as

out = F(G(x_1), x_2)

and a loss function like loss(out, label).
I want to disable gradients for F during one particular update (though not necessarily in others).
I tried this during that update:

g = G(x_1)
with torch.no_grad():
    out = F(g, x_2)

But (correct me if I’m wrong), it seems gradients don’t propagate back to G in this case. I need gradients to still propagate back to G.
Your code won’t work and will raise

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

if you try to call backward() on out or on the output of any subsequent operation on it, since out was calculated inside the no_grad block:
import torch
import torch.nn as nn

G = nn.Linear(1, 1)
F = nn.Linear(1, 1)
x = torch.randn(1, 1)

g = G(x)
with torch.no_grad():
    out = F(g)  # created inside no_grad, so out has no grad_fn

out.backward()
> RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

print(out.grad_fn)
> None
To disable the gradient calculation in F while still backpropagating through it to G, you could set the requires_grad attribute of all parameters of F to False:
G = nn.Linear(1, 1)
F = nn.Linear(1, 1)
x = torch.randn(1, 1)

# freeze F's parameters so they won't accumulate gradients in this update
for param in F.parameters():
    param.requires_grad = False

g = G(x)
out = F(g)
out.backward()

print(out.grad_fn)
> <AddmmBackward object at >
print(G.weight.grad)
> tensor([[0.2877]])
print(F.weight.grad)
> None
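
Since you only need F frozen for one particular update, here is a minimal sketch (my addition, not part of the example above) of how you could toggle the flags back afterwards via nn.Module.requires_grad_(), which just loops over the module's parameters and sets the flag for you:

import torch
import torch.nn as nn

G = nn.Linear(1, 1)
F = nn.Linear(1, 1)
x = torch.randn(1, 1)

# freeze F's parameters for this particular update only
F.requires_grad_(False)

out = F(G(x))
out.backward()
print(G.weight.grad)  # populated: gradients still flow back through F to G
print(F.weight.grad)  # None: F's parameters were frozen

# restore F so later updates train it again
F.requires_grad_(True)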