nn.CrossEntropyLoss expects its target as a torch.LongTensor containing the class indices, i.e. without the channel (class) dimension. In your case, you could simply use:
targets = torch.argmax(targets, 1)
to create your target tensor.
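Here is a minimal sketch of that conversion. It assumes a segmentation-style setup where the model outputs logits of shape [N, C, H, W] and the targets are one-hot encoded along dim 1. The shapes and random data are made up for illustration:

```python
import torch
import torch.nn as nn

# Assumed shapes for illustration: batch N=2, C=3 classes, 4x4 spatial map
N, C, H, W = 2, 3, 4, 4

# Model output: raw, unnormalized logits with a channel (class) dimension
logits = torch.randn(N, C, H, W)

# Hypothetical one-hot targets with the same shape as the logits
class_ids = torch.randint(0, C, (N, H, W))
one_hot = nn.functional.one_hot(class_ids, num_classes=C)  # [N, H, W, C]
one_hot = one_hot.permute(0, 3, 1, 2).float()              # [N, C, H, W]

# Collapse the channel dimension into class indices
targets = torch.argmax(one_hot, 1)  # [N, H, W], dtype torch.int64 (LongTensor)

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)  # scalar loss (default reduction='mean')
print(targets.shape, targets.dtype, loss.item())
```

Because each pixel's one-hot vector has a single 1, argmax over dim 1 recovers exactly the original class index. If your targets are soft probabilities instead, argmax discards that information.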