Checkpointing is not compatible with .grad()

I have two segmentation networks and need to obtain the gradients of the output with respect to the input. It works on one model; the second model, however, throws a checkpointing error, which is surprising since no checkpointing is performed anywhere in my code.

RuntimeError: Checkpointing is not compatible with .grad() or when an inputs parameter is passed to .backward(). Please use .backward() and do not pass its inputs argument.
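From what I can tell, this RuntimeError is raised by torch.utils.checkpoint when a checkpointed forward pass is later differentiated through torch.autograd.grad(), so my working assumption is that something inside the failing net routes its forward through checkpointing even though I never call it myself. A minimal sketch of how the same error can be triggered (the CheckpointedBlock module below is hypothetical and only illustrates the mechanism; it assumes a PyTorch version where checkpoint accepts use_reentrant):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    # hypothetical module whose forward goes through the reentrant
    # checkpoint implementation, as some backbones do internally
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return checkpoint(self.conv, x, use_reentrant=True)

net = CheckpointedBlock()
x = torch.rand(1, 3, 8, 8, requires_grad=True)
out = net(x).sum()
# driving the backward with autograd.grad hits the reentrant checkpoint
# and raises the RuntimeError quoted above
torch.autograd.grad(out, x)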

I added a minimal code snippet to make the case clearer. net_a and net_b are the segmentation networks.

import torch
import torch.nn as nn

net_a = NetA()   # NetA / NetB are the two segmentation networks (definitions omitted)
#net_b = NetB()
net_a.eval()

with torch.no_grad():
    x = torch.rand(1, 3, 768, 768)                      # dummy input image
    y = torch.LongTensor(1, 768, 768).random_(0, 19)    # dummy target with 19 classes

criterion = nn.CrossEntropyLoss()
x = x.clone().detach()
y = y.clone().detach()

x_ = x.clone().detach()
x_.requires_grad = True          # track gradients w.r.t. the input
y_ = net_a(x_)
y_ = y_['logits']

loss = criterion(y_, y)
print(loss.requires_grad, x_.requires_grad)   # sanity check: both should be True
grad_ = torch.autograd.grad(loss, x_, retain_graph=False, create_graph=False)[0]

What might cause this behaviour? Both nets are structured very similarly, hence the confusion. Looking forward to a discussion!
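For reference, this is how I read the fallback the error message suggests, applied to the snippet above: call plain .backward() without an inputs= argument and read the gradient off x_.grad (whether this actually sidesteps the issue for the failing net is part of what I'd like to clarify):

x_ = x.clone().detach()
x_.requires_grad = True
y_ = net_a(x_)['logits']

loss = criterion(y_, y)
loss.backward()        # no inputs= argument, as the message asks
grad_ = x_.grad        # gradient of the loss w.r.t. the input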

Could you make the code snippet executable by adding the missing pieces, so that we can reproduce and debug it, please?