One-hot encoded targets are the wrong format if you are using `nn.CrossEntropyLoss`, as it expects class indices (a `torch.long` tensor of shape `[batch_size]`). If you have already created one-hot encoded targets, use `target = torch.argmax(target, dim=1)` to convert them into the expected target.
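A minimal sketch of the conversion, assuming a batch of 3 samples and 4 classes, with random logits standing in for a model's output (the shapes and variable names are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes = 4
batch_size = 3

# Raw, unnormalized logits as a model would produce them: shape [batch_size, num_classes]
logits = torch.randn(batch_size, num_classes)

# Illustrative one-hot targets: shape [batch_size, num_classes], float like many user-built one-hot tensors
class_indices = torch.tensor([2, 0, 3])
one_hot = nn.functional.one_hot(class_indices, num_classes=num_classes).float()

# Convert back to class indices: shape [batch_size], dtype torch.long
target = torch.argmax(one_hot, dim=1)
print(target)  # tensor([2, 0, 3])

# CrossEntropyLoss applies log_softmax internally, so pass the raw logits
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)
```

Newer PyTorch releases (1.10 and later) also allow `nn.CrossEntropyLoss` to take class probabilities with the same shape as the input, but class indices as shown above remain the usual form for hard labels.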