I’m not quite sure what I’ve done wrong here, or if this is a bug in PyTorch. I’m trying to predict a number of classes (5 in this case), but one of them, class 0, dominates all the others. It is essentially the background class and we aren’t very interested in it, so I want to use the class weights in the cross-entropy loss to emphasise the other 4 classes.
I was reading the CrossEntropyLoss page of the PyTorch 1.9.1 documentation, and it seems that the shape of the first argument (the input tensor) is (minibatch, C, d1, d2, …, dK).
Happy with that. C is apparently the number of classes, which is 5 in my case. That would be 0 to 4 inclusive.
Now, according to the docs, “weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C”, so that would be 5.
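As a sanity check on those shapes, here is a minimal CPU sketch with deliberately small, made-up sizes (the names `N, C, D, H, W` and the random data are illustrative only, not from the original code). The input is `(minibatch, C, d1, d2, d3)`, the target is `(minibatch, d1, d2, d3)` holding integer labels, and the weight tensor has exactly `C` entries.

```python
import torch
from torch import nn

# Illustrative small sizes so the sketch runs quickly; the real data is (2, 5, 26, 150, 320).
N, C, D, H, W = 2, 5, 4, 3, 3

logits = torch.randn(N, C, D, H, W)          # input: (minibatch, C, d1, d2, d3)
target = torch.randint(0, C, (N, D, H, W))   # target: (minibatch, d1, d2, d3), int64 labels in [0, C)

# One weight per class, so the weight tensor must have exactly C entries.
class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0])
assert class_weights.numel() == logits.shape[1]

criterion = nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(logits, target)
print(loss.shape, loss.item())  # loss is a 0-dim (scalar) tensor
```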
Here is the code I have so far:
```python
import torch
from torch import nn

def loss_func(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float16, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.to_dense().long().to(result.device)
    loss = criterion(result, dense)
    return loss
```
And the output I get is the following:
```
torch.Size([2, 5, 26, 150, 320]) torch.Size([2, 26, 150, 320])
Traceback (most recent call last):
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 213, in <module>
    train(args, model, train_data, test_data, optimiser, writer)
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 118, in train
    loss = loss_func(result, target_mask)
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 61, in loss_func
    loss = criterion(result, dense)
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1120, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27
```
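One thing worth checking before suspecting a PyTorch bug: in Python, `1,0` inside a list literal is two separate items (the integers `1` and `0`), not the float `1.0`. The plain-Python sketch below shows what that does to the weight lists used in the code above; `check_class_weights` is a hypothetical helper, not part of PyTorch.

```python
# `1,0` inside a list literal is two items (the int 1 and the int 0), not the float 1.0.
weights_fails = [0.1, 1.0, 1.0, 1,0, 1.0]
weights_works = [1.0, 1.0, 1,0, 1.0]

print(len(weights_fails))  # 6: one more than C = 5
print(len(weights_works))  # 5: matches C = 5, but class 3 gets weight 0

def check_class_weights(weights, num_classes):
    """Hypothetical helper: fail early with a readable message if the weight count is wrong."""
    if len(weights) != num_classes:
        raise ValueError(f"expected {num_classes} class weights, got {len(weights)}: {weights}")
    return weights

check_class_weights([0.1, 1.0, 1.0, 1.0, 1.0], 5)  # passes: five float weights for five classes
```

If that is what is happening here, a six-entry weight for five classes would be consistent with the "defined either for all or no classes" message, while the list that "works" only has the right length by accident and silently gives class 3 a weight of 0.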
Now if I change the weight tensor to [0.1, 1.0, 1.0, 1.0], it works fine; or rather, it runs, but I get NaNs and have to add “ignore_index=0”.
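To reason about where NaNs can come from, here is a pure-Python sketch of the weighted "mean" reduction as described in the PyTorch docs for `CrossEntropyLoss`/`NLLLoss`: each sample's negative log-softmax is multiplied by the weight of its target class, and the total is divided by the sum of those weights, with targets equal to `ignore_index` dropped from both sums. This illustrates the formula only; it is not PyTorch's implementation, and `weighted_cross_entropy` is a hypothetical name. One thing it suggests is that if every target in a batch is ignored, the mean is 0/0, which is one plausible source of NaN.

```python
import math

def weighted_cross_entropy(logits, targets, weights, ignore_index=None):
    """Pure-Python sketch of a class-weighted, mean-reduced cross entropy.

    logits:  list of per-sample score lists, each of length C
    targets: list of integer class labels
    weights: list of C per-class weights
    """
    numerator = 0.0
    denominator = 0.0
    for scores, y in zip(logits, targets):
        if ignore_index is not None and y == ignore_index:
            continue  # ignored targets drop out of both sums
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        nll = log_sum_exp - scores[y]  # -log softmax(scores)[y]
        numerator += weights[y] * nll
        denominator += weights[y]
    if denominator == 0.0:
        return float("nan")  # 0/0, e.g. every target equals ignore_index
    return numerator / denominator

# Every target is class 0 and class 0 is ignored, so there is nothing to average.
print(weighted_cross_entropy([[0.0, 0.0, 0.0]], [0], [0.1, 1.0, 1.0], ignore_index=0))  # nan
```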
I’m using a batch size of 2, with 3D images 320x150x26 pixels in size, so [2, 5, 26, 150, 320] seems correct to me. Is there something I’ve missed, or is this a bug? I am using float16 throughout, which has caused NaNs on occasion, but the message “weight tensor should be defined either for all or no classes” looks like a bug to me.
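On the float16 side, a common pattern (offered here as an assumption, not a confirmed fix for these NaNs) is to let the network run in half precision but compute the loss in float32: upcast the logits with `.float()` and keep the class weights in float32. The tensors below are small random stand-ins for the real network output and mask.

```python
import torch
from torch import nn

# Illustrative stand-ins: a float16 "network output" and int64 labels, small sizes for speed.
logits_fp16 = torch.randn(2, 5, 4, 3, 3).half()
target = torch.randint(0, 5, (2, 4, 3, 3))

# Class weights stay in float32, matching the upcast logits below.
class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Upcast before the loss so log-softmax and the weighted reduction run in float32.
loss = criterion(logits_fp16.float(), target)
print(loss.dtype, loss.item())
```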
Has anyone else run into this? Perhaps I just need to upgrade my PyTorch setup? I’m running 1.9.0.
I’ve put together a minimal reproducible example below:
```python
import torch
from torch import nn

def loss_func_works(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss

def loss_func_fails(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss

if __name__ == "__main__":
    result = torch.tensor((), dtype=torch.float32)
    result = result.new_ones((2, 5, 26, 150, 320))
    target = torch.tensor((), dtype=torch.float32)
    target = target.new_ones((2, 26, 150, 320))
    loss_func_works(result, target)
    loss_func_fails(result, target)
```
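For comparison, here is a sketch of the failing function with every weight written as a float literal (`1.0` rather than `1,0`) and an explicit check that the number of weights equals `result.shape[1]`. The name `loss_func_fixed` and the smaller spatial size are illustrative choices, not part of the original example.

```python
import torch
from torch import nn

def loss_func_fixed(result, target) -> torch.Tensor:
    num_classes = result.shape[1]
    # Every weight written with a dot, so there is exactly one weight per class (C = 5).
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], dtype=result.dtype, device=result.device)
    if class_weights.numel() != num_classes:
        raise ValueError(f"expected {num_classes} class weights, got {class_weights.numel()}")
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    return criterion(result, dense)

if __name__ == "__main__":
    # Smaller spatial size than the original (26, 150, 320) so it runs quickly on CPU.
    result = torch.ones((2, 5, 4, 6, 8))
    target = torch.ones((2, 4, 6, 8))
    print(loss_func_fixed(result, target).item())
```

With all-ones logits every class is equally likely, so each voxel's loss is log(5) ≈ 1.609, and because every target is class 1 the weighted mean is that same value.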