I want to differentiate the output of a model with respect to its input. Here is the code (self is a class similar to nn.Module, and x is the input, a batch of images):
lsm = torch.nn.LogSoftmax(dim=1)
x.requires_grad = True      # track gradients with respect to the input
self.eval()
logits = self.evaluate(x)   # forward pass
out_lsm = lsm(logits)       # per-class log-probabilities
aux_sum = torch.sum(torch.max(out_lsm, dim=1)[0])  # sum of per-sample max log-probs
aux_sum.backward()
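For context, here is a self-contained sketch of the pattern I'm describing (ToyNet is just a placeholder for illustration, not my actual network, and I'm assuming evaluate is a plain forward pass with no torch.no_grad() inside):

import torch
import torch.nn as nn

# placeholder model standing in for the real one
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 8 * 8, 10)

    def evaluate(self, x):
        # assumed to be an ordinary forward pass
        return self.fc(x.flatten(1))

model = ToyNet()
model.eval()

x = torch.randn(4, 3, 8, 8)  # batch of 4 "images"
x.requires_grad = True

lsm = nn.LogSoftmax(dim=1)
logits = model.evaluate(x)
out_lsm = lsm(logits)
aux_sum = torch.sum(torch.max(out_lsm, dim=1)[0])
aux_sum.backward()

print(x.grad.shape)          # torch.Size([4, 3, 8, 8])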
I want to find the gradient of aux_sum with respect to x, but for some reason aux_sum has no grad_fn, and backward() fails with: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn.
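For what it's worth, this is the kind of check I'd use to narrow down where the autograd graph breaks (same variable names as in my snippet above):

# diagnostic: find the first tensor that detached from the graph
print(x.requires_grad)       # should be True
print(logits.requires_grad)  # False here would point at evaluate()
print(logits.grad_fn)        # None means no graph was recorded
print(out_lsm.grad_fn)       # should be a LogSoftmaxBackward node
print(aux_sum.grad_fn)       # None reproduces the backward() error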
As a sanity check, I can do the same process but with a simpler “model”:
x = torch.tensor([0., 1., 2.], requires_grad=True)  # leaf tensor with grad tracking
out_lsm = x ** 2
aux_sum = out_lsm.sum()
aux_sum.backward()
print(x.grad)
>>> tensor([0., 2., 4.])
This runs with no issues, and the printed gradient matches the expected d/dx of sum(x_i**2) = 2*x_i. Does anyone know what is causing the gradient issue in the first case, and why it's not the same as the simplified second case?
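For completeness, once backward() goes through, this is how I plan to read off the input gradient (torch.autograd.grad shown as an alternative that returns it directly; aux_sum and x as in the first snippet):

# after a successful backward(), the gradient sits in x.grad
grad_x = x.grad  # same shape as x, one gradient entry per pixel

# alternative that returns the gradient as a value instead of
# populating x.grad (called instead of backward(), not after it,
# since the graph is freed by backward()):
# grad_x, = torch.autograd.grad(aux_sum, x)

Thank you!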