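You could try the following, which computes the cross-entropy between labels and preds, summed over each sample and averaged over the batch. It assumes preds already holds probabilities (for example, softmax output) and that batch_size matches the first dimension of both tensors:
batch_size = 4
loss = -torch.mean(torch.sum(labels.view(batch_size, -1) * torch.log(preds.view(batch_size, -1)), dim=1))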
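For a version you can run end to end, here is a minimal sketch of the same expression wrapped in a function. The function name soft_cross_entropy, the eps clamp, and the random example tensors are my own assumptions for illustration and are not part of the snippet above; the clamp is only there so that log(0) does not produce -inf.

```python
import torch


def soft_cross_entropy(preds, labels, eps=1e-12):
    """Cross-entropy between probability tensors, averaged over the batch.

    Assumes preds holds probabilities (e.g. softmax output), not raw logits.
    The name and the eps parameter are illustrative assumptions.
    """
    batch_size = preds.shape[0]
    # Flatten everything except the batch dimension.
    preds = preds.reshape(batch_size, -1)
    labels = labels.reshape(batch_size, -1)
    # Clamp so that log(0) cannot yield -inf.
    log_preds = torch.log(torch.clamp(preds, min=eps))
    # Sum over classes per sample, then average over the batch.
    return -torch.mean(torch.sum(labels * log_preds, dim=1))


# Made-up example data: 4 samples, 3 classes.
batch_size, num_classes = 4, 3
preds = torch.softmax(torch.randn(batch_size, num_classes), dim=1)
labels = torch.softmax(torch.randn(batch_size, num_classes), dim=1)

loss = soft_cross_entropy(preds, labels)
print(loss.item())
```

If your model outputs raw logits instead of probabilities, it is usually more stable numerically to use torch.log_softmax on the logits in place of torch.log(preds).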