I have two segmentation networks and need the gradients of the loss with respect to the input. It works on one model. On the second model, however, it throws a checkpointing error, which is surprising since no checkpointing is performed explicitly.
RuntimeError: Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward(). Please use .backward() and do not pass its `inputs` argument.
I added a minimal code snippet to make the case clearer. net_a and net_b are the segmentation networks.
import torch
import torch.nn as nn

net_a = NetA()
# net_b = NetB()
net_a.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 768, 768)
    y = torch.LongTensor(1, 768, 768).random_(0, 19)

criterion = nn.CrossEntropyLoss()

x_ = x.clone().detach()
x_.requires_grad = True
y_ = net_a(x_)['logits']
loss = criterion(y_, y)
print(loss.requires_grad, x_.requires_grad)
grad_ = torch.autograd.grad(loss, x_, retain_graph=False, create_graph=False)[0]
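For what it's worth, I can trigger the exact same error with a toy module that wraps part of its forward pass in torch.utils.checkpoint. Whether net_b actually does this somewhere internally is only my guess; CheckpointedNet below is a hypothetical stand-in, not code from either network:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for a net that checkpoints part of its
# forward pass (an assumption about net_b, not its actual code).
class CheckpointedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 19, 3, padding=1)

    def forward(self, x):
        # Reentrant checkpointing re-runs self.conv during backward,
        # which is incompatible with torch.autograd.grad().
        return checkpoint(self.conv, x, use_reentrant=True)

net = CheckpointedNet()
net.eval()
x_ = torch.rand(1, 3, 8, 8).requires_grad_(True)
out = net(x_).sum()

caught = None
try:
    torch.autograd.grad(out, x_)
except RuntimeError as e:
    caught = e
print(caught)  # "Checkpointing is not compatible with .grad() ..."
```

So my current suspicion is that something inside net_b checkpoints its forward, but I don't see it in the model code.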
What might cause this behaviour? Both nets are structured very similarly, hence my confusion so far. Looking forward to a discussion!