Hi, I am developing a UNet model for bio-medical images. I was wondering if I could pass the predictions as B x C x H x W and the target also as B x C x H x W, where I preprocessed the target mask so that along the C dimension there is a 1 in the channel of the respective class (i.e. one-hot encoded). Is this the correct way? I have seen people saying otherwise.
Yes, this is possible in newer PyTorch releases, since nn.CrossEntropyLoss supports "soft" (probabilistic) targets as of PyTorch 1.10.
However, using the "classical" class indices would save some memory, since only a single index (rather than C channels) needs to be stored per pixel in the target.
Here is a small example showing both approaches:
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
N, C, H, W = 2, 3, 4, 4
x = torch.randn(N, C, H, W, requires_grad=True)

# class indices: target has shape B x H x W
target = torch.randint(0, C, (N, H, W))
loss = criterion(x, target)
print(loss)
# tensor(1.7541, grad_fn=<NllLoss2DBackward0>)
# soft targets: one-hot encode the indices, then move the class dim to dim1
soft_target = F.one_hot(target, num_classes=C).float()
print(soft_target.shape)
# torch.Size([2, 4, 4, 3])
soft_target = soft_target.permute(0, 3, 1, 2)
print(soft_target.shape)
# torch.Size([2, 3, 4, 4])
loss = criterion(x, soft_target)
print(loss)
# tensor(1.7541, grad_fn=<DivBackward1>)
My dataset is heavily imbalanced, and I think my model is biased towards the background label. Would ignore_index work with the one-hot encoded version, or should I just argmax the target and pass ignore_index=0, where 0 is the background channel?
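Note that ignore_index only applies when the target contains class indices, not soft targets, so you would first need to recover indices (e.g. via argmax on the class dimension). A minimal sketch of both options, where treating channel 0 as the background class is an assumption based on your description:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 3, 4, 4
x = torch.randn(N, C, H, W, requires_grad=True)

# one-hot target as in your setup: B x C x H x W
target = torch.randint(0, C, (N, H, W))
one_hot = F.one_hot(target, num_classes=C).permute(0, 3, 1, 2).float()

# recover class indices from the one-hot mask, then ignore the background class 0
indices = one_hot.argmax(dim=1)
criterion = nn.CrossEntropyLoss(ignore_index=0)
loss = criterion(x, indices)

# alternatively, keep the background in the loss but down-weight it via `weight`
# (the per-class weights here are hypothetical; tune them for your dataset)
weight = torch.tensor([0.1, 1.0, 1.0])
criterion_w = nn.CrossEntropyLoss(weight=weight)
loss_w = criterion_w(x, indices)
```

Which option is better depends on whether background pixels should contribute nothing at all to the loss (ignore_index) or simply contribute less (weight).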