Only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

I ran into this error. Sample code to reproduce it:

import torch
import torch.nn as nn

prediction = torch.rand((2, 3, 5, 5))
label = torch.randint(0, 2, (2, 3, 5, 5))
criterion = nn.CrossEntropyLoss()
loss = criterion(prediction, label)
Traceback (most recent call last):
  File "/home/moshood/Documents/AMMI/AMMI_Research_Project/", line 240, in <module>
    loss = criterion(prediction, label)
  File "/home/moshood/.local/lib/python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/moshood/.local/lib/python3.8/site-packages/torch/nn/modules/", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/moshood/.local/lib/python3.8/site-packages/torch/nn/", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/moshood/.local/lib/python3.8/site-packages/torch/nn/", line 2266, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

I read posts about CrossEntropyLoss expecting targets that are not one-hot encoded, but in my case the label per channel is either one or zero. I would be glad for your assistance.

As you said, nn.CrossEntropyLoss doesn’t accept one-hot encoded targets, which seems to be the case in your workflow. Pass class indices instead, or use e.g. nn.BCEWithLogitsLoss if you are dealing with a multi-label classification.
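For reference, a minimal sketch of both options (not from the original post; shapes chosen to match the example above):

```python
import torch
import torch.nn as nn

# Option 1: class-index targets for nn.CrossEntropyLoss.
# prediction: (N, C, H, W); target: (N, H, W) with integer values in [0, C).
prediction = torch.rand(2, 3, 5, 5)
target = torch.randint(0, 3, (2, 5, 5))  # int64 class indices, no channel dim
loss_ce = nn.CrossEntropyLoss()(prediction, target)

# Option 2: multi-label targets for nn.BCEWithLogitsLoss.
# prediction and target share the same shape; target is a float 0/1 per channel.
multi_label = torch.randint(0, 2, (2, 3, 5, 5)).float()
loss_bce = nn.BCEWithLogitsLoss()(prediction, multi_label)

print(loss_ce.item(), loss_bce.item())
```

Note that option 1 treats the channels as mutually exclusive classes, while option 2 scores each channel as an independent binary label.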


You mean taking the argmax of both predictions and labels?

I followed your approach and got the loss.

import torch
import torch.nn as nn

prediction = torch.rand((1, 3, 5, 5), dtype=torch.float32)
label = torch.randint(0, 2, (1, 3, 5, 5), dtype=torch.float32)
criterion = nn.CrossEntropyLoss()
label = torch.argmax(label, 1)  # collapse the channel dim to class indices
print(label.shape, prediction.shape, label.dtype, prediction.dtype)
loss = criterion(prediction, label)

Output:
torch.Size([1, 5, 5]) torch.Size([1, 3, 5, 5]) torch.int64 torch.float32
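One caveat worth noting: argmax only recovers the intended class when exactly one channel is 1 per pixel (a true one-hot mask); with random 0/1 labels as above, pixels with several ones or all zeros lose information. A quick sketch (my own, not from the thread) checking the round trip on a genuinely one-hot mask:

```python
import torch
import torch.nn.functional as F

# Build class-index labels, one-hot encode them, then recover via argmax.
idx = torch.randint(0, 3, (1, 5, 5))           # (N, H, W) class indices
one_hot = F.one_hot(idx, num_classes=3)        # (N, H, W, C)
one_hot = one_hot.permute(0, 3, 1, 2).float()  # (N, C, H, W), as the model expects
recovered = torch.argmax(one_hot, dim=1)       # back to (N, H, W)
assert torch.equal(recovered, idx)             # round trip is exact for one-hot masks
```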