Hi @bing, you can't differentiate `torch.argmax` with respect to `output` (as `torch.argmax` has no `grad_fn`), so you need to find another way to convert your `output` tensor to a prediction with an operation that does have a `grad_fn`. A minimal example below shows that `torch.argmax` has no `grad_fn`:
```python
import torch

x = torch.randn(10, 4, requires_grad=True)
output = torch.argmax(x, dim=1)
print(output.grad_fn)  # returns None
```
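For comparison, here is a minimal sketch showing that a differentiable op such as `torch.softmax` (unlike `torch.argmax`) keeps a `grad_fn`, so gradients can flow back to the input (the tensor names here are just for illustration):

```python
import torch

x = torch.randn(10, 4, requires_grad=True)

# softmax is differentiable, so its result carries a grad_fn
probs = torch.softmax(x, dim=1)
print(probs.grad_fn)  # a SoftmaxBackward object, not None

# gradients flow back through softmax to x
probs.sum().backward()
print(x.grad.shape)  # torch.Size([10, 4])
```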
You might just be able to remove the `torch.argmax` call, as your loss seems to expect the raw logits, and replace the loss calculation with:

```python
loss = criterion(output.float(), target.float())  # convert pred and target to float
```
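As a concrete sketch (assuming `criterion` is something like `nn.CrossEntropyLoss`, which consumes raw logits and integer class targets; your actual criterion, dtypes, and shapes may differ):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # assumption: a loss that expects raw logits
logits = torch.randn(10, 4, requires_grad=True)  # raw model output, no argmax
target = torch.randint(0, 4, (10,))  # integer class labels

loss = criterion(logits, target)
print(loss.grad_fn)  # not None: the loss is differentiable
loss.backward()
print(logits.grad.shape)  # torch.Size([10, 4])
```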
More info in this post here (about logits with a different loss function), which you might find useful.