I have two segmentation networks and need to obtain the gradients of the output with respect to the input. It works on one model. On the second model, however, it throws a checkpointing error, which is surprising since no checkpointing is performed anywhere:
RuntimeError: Checkpointing is not compatible with .grad() or when an inputs
parameter is passed to .backward(). Please use .backward() and do not pass its inputs
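For context, the only mechanism I know of that produces this exact message is torch.utils.checkpoint: when checkpoint() is called somewhere inside a forward pass, differentiating the resulting graph with torch.autograd.grad() raises this error. A minimal sketch of that situation (TinyNet is a made-up stand-in, not my actual model):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 19, 1)

    def forward(self, x):
        # gradient checkpointing inside forward; the reentrant variant
        # (use_reentrant=True, the long-standing default) is the one that
        # refuses to run under autograd.grad()
        return checkpoint(self.conv, x, use_reentrant=True)

net = TinyNet()
x = torch.rand(1, 3, 8, 8, requires_grad=True)
out = net(x).sum()
# raises: RuntimeError: Checkpointing is not compatible with .grad() ...
grad = torch.autograd.grad(out, x)

So my suspicion is that something inside the second network checkpoints part of its forward, even though I never call checkpoint() myself.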
I added a minimal code snippet to make the case clearer. net_a and net_b are the segmentation networks.
import torch
import torch.nn as nn

net_a = NetA()    # first segmentation network (works)
# net_b = NetB()  # second segmentation network (raises the error above)
net_a.eval()

# Dummy input and target, created without tracking gradients
with torch.no_grad():
    x = torch.rand(1, 3, 768, 768)
    y = torch.LongTensor(1, 768, 768).random_(0, 19)

criterion = nn.CrossEntropyLoss()
x = x.clone().detach()
y = y.clone().detach()

# A copy of the input that requires grad, so we can differentiate w.r.t. it
x_ = x.clone().detach()
x_.requires_grad = True

y_ = net_a(x_)
y_ = y_['logits']
loss = criterion(y_, y)
print(loss.requires_grad, x_.requires_grad)

grad_ = torch.autograd.grad(loss, x_, retain_graph=False, create_graph=False)
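The error message itself points to the .backward() route; a sketch of that alternative (continuing from the snippet above, and assuming I only need the gradient of the loss with respect to x_):

x_ = x.clone().detach()
x_.requires_grad = True
out = net_a(x_)['logits']
loss = criterion(out, y)
loss.backward()        # no inputs= argument passed here
grad_ = x_.grad        # gradient of the loss w.r.t. the input

Note that .backward() also accumulates gradients into the network's parameters, so those would need to be zeroed or frozen if only the input gradient matters.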
What might cause this behaviour? Both nets are structured very similarly, hence the confusion so far. Looking forward to a discussion!