Unrelated to your question, but note that nn.CrossEntropyLoss
expects raw logits as the model output, not probabilities coming from softmax.
Internally, F.log_softmax
and nn.NLLLoss
will be applied, so you can just remove the softmax as the output activation.
Also note that you can call torch.argmax
directly on the tensor without transforming it to numpy and back to PyTorch.
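A minimal sketch of both points, using a hypothetical toy model (the layer sizes and batch shape are made up for illustration):

```python
import torch
import torch.nn as nn

# Toy model: note there is no softmax at the end -- the last layer
# returns raw logits, which is what nn.CrossEntropyLoss expects.
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(4, 10)              # 4 samples, 10 features
target = torch.randint(0, 3, (4,))  # class indices in [0, 3)

logits = model(x)                   # raw logits, no softmax applied
loss = criterion(logits, target)    # log_softmax + NLLLoss happen internally

# Predictions: argmax directly on the tensor, no numpy round trip needed.
preds = torch.argmax(logits, dim=1)
```

If you do need probabilities (e.g. for logging), apply `torch.softmax(logits, dim=1)` separately and keep feeding the raw logits to the criterion.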