Hello there!
I want to compute the gradient with respect to the input of a model that uses a torch.no_grad() block in its forward function.
In the example below, the resulting grad is None, even though x.requires_grad is True when I debug the forward pass.
Here is an example script:
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self, in_dim=10, out_dim=1):
        super(Network, self).__init__()
        self.instancenorm = nn.InstanceNorm1d(in_dim)
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        with torch.no_grad():
            x = self.instancenorm(x)
        x = self.fc(x)
        return x


model = Network(in_dim=10, out_dim=1)
model.eval()

x = torch.rand(1, 10)
x.requires_grad_(True)

score = model(x)
grad = torch.autograd.grad(
    outputs=score,
    inputs=x,
    allow_unused=True)[0]
print(grad)
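As a side note, here is a minimal check (independent of the model above, using an illustrative multiply instead of the instance norm) showing that tensors produced inside a torch.no_grad() block come out detached from the graph, even when the input itself requires grad:

```python
import torch

x = torch.rand(1, 10, requires_grad=True)
with torch.no_grad():
    y = x * 2  # computed inside no_grad

print(x.requires_grad)  # True
print(y.requires_grad)  # False: y carries no graph connection to x
```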
When I remove the with torch.no_grad(): block, it works. However, setting x.requires_grad_(True) after the torch.no_grad() block also results in None.
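For completeness, a minimal sketch of that second case (again with an illustrative multiply standing in for the instance norm): re-enabling requires_grad on the output of the no_grad block starts a fresh graph at that tensor, so the gradient with respect to the original input is still None:

```python
import torch

x = torch.rand(1, 10, requires_grad=True)
with torch.no_grad():
    y = x * 2          # y is created without any graph connection to x
y.requires_grad_(True)  # autograd now tracks from y onward only

z = y.sum()
grad = torch.autograd.grad(z, x, allow_unused=True)[0]
print(grad)  # None: no autograd path from z back through y to x
```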
What causes this, and how can I solve it? And why is it not a problem during training?
Thanks!