PyTorch CrossEntropyLoss: number of weights not matching number of classes?

I’m not quite sure what I’ve done wrong here, or if this is a bug in PyTorch. I’m trying to predict a number of classes (5 in this case), but one of them, class 0, dominates all the others. This is essentially the background class and we aren’t too interested in it, so I want to use the weights in the cross entropy function to emphasise the other 4 classes.

I was reading the CrossEntropyLoss documentation (PyTorch 1.9.1)...

... and it seems like the shape of the first argument (the input tensor) is (minibatch, C, d1, d2, ..., dK).
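For the K-dimensional case, the target drops the C dimension: it has shape (minibatch, d1, ..., dK) and holds integer class indices. A minimal sketch of that shape relationship, using small made-up sizes in place of the real 26x150x320 volumes:

```python
import torch
from torch import nn

# Small stand-in sizes (assumed for illustration): N=2, C=5, spatial 3x4x6
N, C = 2, 5
spatial = (3, 4, 6)

logits = torch.randn(N, C, *spatial)          # input: (minibatch, C, d1, d2, d3)
target = torch.randint(0, C, (N, *spatial))   # target: (minibatch, d1, d2, d3), int64 class indices

# The target shape is the input shape with the class dimension (dim 1) removed
assert target.shape == logits.shape[:1] + logits.shape[2:]

loss = nn.CrossEntropyLoss()(logits, target)
print(loss.shape)  # torch.Size([]), a scalar with the default reduction="mean"
```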

Happy with that. C is apparently the number of classes, which is 5 in my case. That would be labels 0 to 4 inclusive.
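Since C = 5 means the target should only contain the integer labels 0 to 4, a quick check on the dense target can catch label problems before they reach the loss. A sketch; `check_target_labels` is a hypothetical helper, not a PyTorch function, and it does not account for any `ignore_index` values:

```python
import torch

def check_target_labels(target: torch.Tensor, num_classes: int) -> None:
    """Hypothetical helper: verify a dense target holds valid class indices."""
    # CrossEntropyLoss expects int64 (long) class indices for this form of target
    if target.dtype != torch.long:
        raise TypeError(f"expected torch.long target, got {target.dtype}")
    lo, hi = int(target.min()), int(target.max())
    # Note: this simple range check would also reject ignore_index values such as -100
    if lo < 0 or hi >= num_classes:
        raise ValueError(f"labels must be in [0, {num_classes - 1}], got [{lo}, {hi}]")

target = torch.randint(0, 5, (2, 3, 4, 6))
check_target_labels(target, num_classes=5)  # passes: labels are in 0..4
```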

Now according to the docs, “weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C”, so that would be a tensor of size 5.
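One way to make the “size C” requirement hard to get wrong is to build the weight tensor from the class count rather than typing every entry by hand. A sketch, assuming (as in this post) that class 0 is the background class to be down-weighted to 0.1:

```python
import torch
from torch import nn

num_classes = 5
# Start with weight 1.0 for every class, then down-weight the background class 0
class_weights = torch.ones(num_classes, dtype=torch.float32)
class_weights[0] = 0.1

print(class_weights)  # tensor([0.1000, 1.0000, 1.0000, 1.0000, 1.0000])
criterion = nn.CrossEntropyLoss(weight=class_weights)
```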

Here is the code I have so far:

```python
import torch
from torch import nn

def loss_func(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float16, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.to_dense().long().to(result.device)
    loss = criterion(result, dense)
    return loss
```
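Because the weight size is only checked when the loss is evaluated, it can help to validate it against the channel dimension of the network output before building the criterion. A sketch with a hypothetical helper `make_weighted_ce` (not part of PyTorch), which also matches the weight’s device and dtype to the output:

```python
import torch
from torch import nn

def make_weighted_ce(class_weights: torch.Tensor, result: torch.Tensor) -> nn.CrossEntropyLoss:
    """Hypothetical helper: build a weighted CrossEntropyLoss after checking sizes."""
    num_classes = result.shape[1]  # C is dimension 1 of the (N, C, d1, ..., dK) output
    if class_weights.numel() != num_classes:
        raise ValueError(
            f"got {class_weights.numel()} class weights for {num_classes} classes"
        )
    # Move the weights to the same device and dtype as the network output
    weight = class_weights.to(device=result.device, dtype=result.dtype)
    return nn.CrossEntropyLoss(weight=weight)

result = torch.randn(2, 5, 3, 4, 6)  # small stand-in for the real network output
criterion = make_weighted_ce(torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0]), result)
```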

And the output I get is the following:

```
torch.Size([2, 5, 26, 150, 320]) torch.Size([2, 26, 150, 320])
Traceback (most recent call last):
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 213, in <module>
    train(args, model, train_data, test_data, optimiser, writer)
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 118, in train
    loss = loss_func(result, target_mask)
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 61, in loss_func
    loss = criterion(result, dense)
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1120, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27
```
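The error can be reproduced on the CPU with small tensors; what matters is only that the number of weights differs from the size of dimension 1 of the input. A sketch (the exact error text may differ between PyTorch versions and between CPU and CUDA):

```python
import torch
from torch import nn

logits = torch.randn(2, 5, 3, 4, 6)          # C = 5 classes, small spatial dims
target = torch.randint(0, 5, (2, 3, 4, 6))

six_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 0.0, 1.0])  # 6 entries: one too many
five_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0])      # 5 entries: one per class

try:
    nn.CrossEntropyLoss(weight=six_weights)(logits, target)
    mismatch_raised = False
except RuntimeError as err:
    # PyTorch's internal size check surfaces as a RuntimeError
    mismatch_raised = True
    print("mismatch:", err)

loss = nn.CrossEntropyLoss(weight=five_weights)(logits, target)
print("five weights OK:", loss.item())
```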

Now if I change the weight tensor to [0.1, 1.0, 1.0, 1.0] it works fine, or rather, I get NaNs and have to add “ignore_index=0”.
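Two details of the weighted loss are relevant here. With the default reduction="mean", the PyTorch docs describe the result as a weighted average: each element’s loss is multiplied by the weight of its target class, and the total is divided by the sum of those weights rather than by the element count. Setting ignore_index=0 removes background elements from both sums, so the 0.1 weight given to class 0 then has no effect. A sketch checking both points on small made-up tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 5
logits = torch.randn(8, C)                        # 8 samples, 5 classes (made-up data)
target = torch.tensor([0, 1, 2, 3, 4, 0, 2, 1])   # fixed labels, includes background 0
weight = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0])

builtin = F.cross_entropy(logits, target, weight=weight)

# Manual weighted mean: each sample's negative log-likelihood is scaled by the
# weight of its target class, then divided by the sum of those weights
nll = -F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
w = weight[target]
manual = (w * nll).sum() / w.sum()
print(torch.allclose(builtin, manual))  # True

# ignore_index=0 drops class-0 samples from both numerator and denominator,
# which gives the same result as setting the class-0 weight to zero
ignored = F.cross_entropy(logits, target, weight=weight, ignore_index=0)
zero_bg = F.cross_entropy(logits, target, weight=torch.tensor([0.0, 1.0, 1.0, 1.0, 1.0]))
print(torch.allclose(ignored, zero_bg))  # True
```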

I’m using a batch size of 2, with 3D images 320x150x26 pixels in size, so [2, 5, 26, 150, 320] seems correct to me. Is there something I’ve missed, or is this a bug? I am using float16 throughout, and that has caused NaNs on occasion, but the error “weight tensor should be defined either for all or no classes” looks like a bug to me.
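Since float16 is mentioned as an occasional source of NaNs, one common mitigation (a general suggestion, not something confirmed in this thread) is to keep the model in half precision but compute the loss in float32. A sketch with random stand-in data on the CPU:

```python
import torch
from torch import nn

logits_fp16 = torch.randn(2, 5, 3, 4, 6).half()  # stand-in for half-precision network output
target = torch.randint(0, 5, (2, 3, 4, 6))

class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Upcast the logits so log-softmax and the weighted reduction run in float32
loss = criterion(logits_fp16.float(), target)
print(loss.dtype)  # torch.float32
```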

Has anyone else run into this? Perhaps I just need to upgrade my PyTorch setup? I’m running 1.9.0.

I’ve put together a minimal reproducible example below:

```python
import torch
from torch import nn

def loss_func_works(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss


def loss_func_fails(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss


if __name__ == "__main__":
    result = torch.tensor((), dtype=torch.float32)
    result = result.new_ones((2, 5, 26, 150, 320))
    target = torch.tensor((), dtype=torch.float32)
    target = target.new_ones((2, 26, 150, 320))

    loss_func_works(result, target)
    loss_func_fails(result, target)
```

Hi Benjamin!

You have a simple typo where you have a comma (",") where you should have a decimal point ("."), so you end up with one more weight than you have classes:

```
>>> class_weights = torch.tensor ([0.1, 1.0, 1.0, 1,0, 1.0])
>>> class_weights.shape
torch.Size([6])
>>> class_weights = torch.tensor ([0.1, 1.0, 1.0, 1.0, 1.0])
>>> class_weights.shape
torch.Size([5])
```
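The typo slips through silently because Python parses `1,0` inside a list literal as two separate elements, the integers 1 and 0. A sketch of what the two weight lists in the example actually contain; it also shows that the tensor in `loss_func_works` has the same typo but happens to end up with five entries, so class 3 is silently given weight 0:

```python
import torch

# Weights from loss_func_fails: "1,0" becomes two elements, so there are 6 weights
fails = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0])
# Weights from loss_func_works: the same typo, but only 5 entries in total
works = torch.tensor([1.0, 1.0, 1,0, 1.0])

print(fails.shape)  # torch.Size([6])
print(works)        # tensor([1., 1., 1., 0., 1.])
```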

Best.

K. Frank

Thanks! Good spot! All seems fine now.