Similar issue here. Try changing the criterion to torch.nn.CrossEntropyLoss:
criterion = torch.nn.CrossEntropyLoss,
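The line above reads like a keyword argument from a larger call that isn't shown in this thread, so I can't say exactly where it goes in your code. As a rough sketch in plain PyTorch, assuming your model outputs raw logits of shape (batch, classes) and your targets are integer class indices, this is how the criterion behaves. The toy tensors are made up for illustration only.

```python
import torch

# Hypothetical toy data (not from the original post): 4 samples, 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 1.5, 0.3],
                       [-0.5, 0.2, 2.2],
                       [1.0, 1.0, 1.0]])
targets = torch.tensor([0, 1, 2, 0])  # integer class indices (int64)

# In plain PyTorch the loss class is instantiated before it is called.
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, targets)

# CrossEntropyLoss applies log-softmax internally, so it should match
# NLLLoss evaluated on log-softmax outputs.
reference = torch.nn.NLLLoss()(torch.log_softmax(logits, dim=1), targets)
```

If your model already ends in a softmax or log-softmax layer, that layer would be applied twice with this criterion, so it may be worth checking the last layer as well.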