I want to compute the gradients of a model whose forward function contains a torch.no_grad() block.
In the example below, the resulting gradient is None, even though x.requires_grad is True when I inspect the forward pass in the debugger.
Here is an example script:
```python
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self, in_dim=10, out_dim=1):
        super(Network, self).__init__()
        self.instancenorm = nn.InstanceNorm1d(in_dim)
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        with torch.no_grad():
            x = self.instancenorm(x)
        x = self.fc(x)
        return x


model = Network(in_dim=10, out_dim=1)
model.eval()
x = torch.rand(1, 10)
x.requires_grad_(True)
score = model(x)
grad = torch.autograd.grad(outputs=score, inputs=x, allow_unused=True)
print(grad)
```
When I remove
with torch.no_grad(): ...
it works. However, calling x.requires_grad_(True) again after the torch.no_grad() block also results in None.
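The behaviour can be reproduced without the module at all. In this minimal sketch (the names x, w, and score are my own; w just stands in for the fc layer's weight), the operation executed under torch.no_grad() is not recorded in the autograd graph, so there is no path from the output back to x:

```python
import torch

x = torch.rand(1, 10, requires_grad=True)
w = torch.rand(10, 1, requires_grad=True)  # stands in for the fc layer's weight

# The op inside no_grad is not recorded, so the graph back to x is cut.
with torch.no_grad():
    y = x * 2.0          # y.requires_grad is False
score = y @ w            # score requires grad, but only through w
g_x, = torch.autograd.grad(score, x, allow_unused=True)
print(g_x)               # None: no path from score back to x

# Without no_grad the path is intact and the gradient exists.
x2 = torch.rand(1, 10, requires_grad=True)
y2 = x2 * 2.0
score2 = y2 @ w
g_x2, = torch.autograd.grad(score2, x2)
print(g_x2.shape)        # torch.Size([1, 10])
```

Note that score still has requires_grad=True in the first case (through w), which is why torch.autograd.grad does not raise; allow_unused=True then lets it return None for the disconnected input instead of erroring.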
What causes this, and how can I fix it? And why is this not a problem during training?