Multi-class semantic segmentation using U-Net: error with Binary Cross Entropy with Logits

My current implementation of the loss function for multi-class (5 classes) segmentation in PyTorch, on my U-Net with a pre-trained ResNet encoder, is the following:

import torch
import torch.nn.functional as F


def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()

    # Sum over the spatial dimensions (height, then width) per sample and per channel
    intersection = (pred * target).sum(dim=2).sum(dim=2)

    loss = (1 - ((2. * intersection + smooth) /
                 (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))

    return loss.mean()


def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    # Weighted combination of BCE and dice loss
    loss = bce * bce_weight + dice * (1 - bce_weight)

    # Accumulate per-batch metrics, weighted by the batch size
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss
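
For context, a minimal hypothetical sketch of how calc_loss would be driven from a training loop; the tiny model, optimizer, and random data below only stand in for the real U-Net and dataloader:

from collections import defaultdict

import torch
import torch.nn as nn

# Stand-ins for the real model/data, just to make the sketch self-contained.
model = nn.Conv2d(3, 5, kernel_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataloader = [(torch.randn(4, 3, 320, 480), torch.randint(0, 2, (4, 5, 320, 480)).float())]

metrics = defaultdict(float)
epoch_samples = 0

for inputs, labels in dataloader:
    outputs = model(inputs)
    loss = calc_loss(outputs, labels, metrics)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    epoch_samples += inputs.size(0)

# Each batch metric was multiplied by the batch size, so dividing by the total
# number of samples gives a per-sample average over the epoch.
epoch_loss = metrics['loss'] / epoch_samples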

The error is the following:

   2432     if not (target.size() == input.size()):
-> 2433         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2434 
   2435     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([4, 1, 320, 480, 3])) must be the same as input size (torch.Size([4, 5, 320, 480]))

F.binary_cross_entropy_with_logits expects the model output and target to have the same shape and can be used for a multi-label segmentation (each pixel can belong to zero, one, or multiple classes).
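
As a small sketch of that multi-label case (the shapes below are just taken from the error message, not from your model):

import torch
import torch.nn.functional as F

# Multi-label segmentation: the target is a float tensor of 0s and 1s with the
# same shape as the logits, one channel per class.
logits = torch.randn(4, 5, 320, 480)                     # [batch, n_classes, H, W]
target = torch.randint(0, 2, (4, 5, 320, 480)).float()   # same shape as logits
loss = F.binary_cross_entropy_with_logits(logits, target)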

Since you’ve mentioned a multi-class segmentation (each pixel belongs to one class only), you should use nn.CrossEntropyLoss instead (or nn.NLLLoss with F.log_softmax as the last non-linearity).

nn.CrossEntropyLoss expects the model output to have the shape [batch_size, nb_classes, height, width] and the target [batch_size, height, width] containing the class indices in the range [0, nb_classes-1].
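
For illustration, a minimal sketch of those shapes (again assuming the 4x320x480 sizes from the error message):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
output = torch.randn(4, 5, 320, 480)           # [batch_size, nb_classes, H, W], raw logits
target = torch.randint(0, 5, (4, 320, 480))    # [batch_size, H, W], class indices in [0, 4]
loss = criterion(output, target)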

Your target seems to have an additional dimension in dim1 and might be a color image in the channels-last format?
If so, you would have to map the colors to class indices.
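
A possible way to do that mapping, assuming you know the RGB color of each class; the color_to_class palette below is made up purely for illustration:

import torch

# Hypothetical palette: one RGB color per class (values are illustrative only).
color_to_class = {
    (0, 0, 0): 0,        # background
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
    (255, 255, 0): 4,
}

def rgb_mask_to_indices(mask):
    # mask: [H, W, 3] uint8 tensor -> [H, W] LongTensor of class indices
    indices = torch.zeros(mask.shape[:2], dtype=torch.long)
    for color, cls in color_to_class.items():
        matches = (mask == torch.tensor(color, dtype=mask.dtype)).all(dim=-1)
        indices[matches] = cls
    return indices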

I already wrote a possible function to convert the mask image into a binary mask of shape [n_classes, height, width], where the background is 0 and the class is 1. Each pixel belongs to only one class.

import numpy as np

final_mask = []
# Pixel values in the original mask image; each value corresponds to one class.
mask_palette = [1, 2, 3, 4, 5]

for palette in mask_palette:
    new_mask = np.zeros((mask_array.shape[0], mask_array.shape[1]), dtype=np.float32)
    new_mask[mask_array == palette] = 1   # 1 where the pixel belongs to this class
    final_mask.append(new_mask)

final_mask = np.asarray(final_mask)       # shape: [n_classes, height, width]

final_mask has shape [5, 600, 900].

So after batching in the dataloader, this method creates masks of shape
[batch_size, nb_classes, height, width],

which is the same size as the output generated by my model.

Which loss and activation should I use in that case?

The final_mask looks like the following:

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

Hi, if I am going to use CrossEntropyLoss, how should the output and target be set up?

Currently, I have 5 classes. If I make the mask a 1-channel image with class numbers, should I use 0 as the background? Then I would have 0, 1, 2, 3, 4, 5, i.e. 6 classes?

If you are using nn.CrossEntropyLoss, you should use this approach:

Your target mask should not contain a channel dimension, and for 5 classes it should contain the class indices in the range [0, 4]. If the background class is not included in these 5 classes, then you should use 6 classes. It doesn't matter if the background is class index 0 or any other index.
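
If you keep the [n_classes, height, width] binary mask from the snippet above, one possible sketch (assuming every pixel has exactly one active channel) is to turn it into class indices with argmax and feed that to nn.CrossEntropyLoss:

import torch
import torch.nn as nn

# Dummy one-hot style mask standing in for the real final_mask ([n_classes, H, W])
final_mask = torch.zeros(5, 600, 900)
final_mask[0] = 1.0                              # pretend every pixel belongs to class 0

target = final_mask.argmax(dim=0)                # [H, W], class indices in [0, 4]
output = torch.randn(1, 5, 600, 900)             # model logits: [batch, n_classes, H, W]

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target.unsqueeze(0))    # target: [batch, H, W], dtype long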

Can you explain why we multiply by target.size(0) in the loss function?

Hello,
I want to use CrossEntropyLoss for my segmentation code, but I am confused!
I calculated the loss in two ways and got two different results, even though both of them are recommended on the PyTorch forums!
E.g. my code is like this:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

y = torch.randint(low=0, high=4, size=(2, 1, 3, 3)).type(torch.LongTensor)
y_hat = torch.randn(2, 4, 3, 3)
print('\ntarget shape like it  [batch_size, 1 ,h,w] =', y.shape)
print('\npredict shape like it [batch_size,N_class,h,w] =', y_hat.shape)

print('\nloss without flatten = ', criterion(y_hat, y.squeeze(1)))
print('\nloss with flatten = ', criterion(y_hat.view(-1, 4), y.view([-1, ])))

and the output:

target shape like it  [batch_size, 1 ,h,w] = torch.Size([2, 1, 3, 3])

predict shape like it [batch_size,N_class,h,w] = torch.Size([2, 4, 3, 3])

loss without flatten =  tensor(2.0514)

loss with flatten =  tensor(1.6351)

Which one is correct?

The view operation is wrong, as it will interleave the values; you would need to permute the tensors beforehand:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
y = torch.randint(low=0, high=4, size=(2, 1, 3, 3)).type(torch.LongTensor)
y_hat = torch.randn(2, 4, 3, 3)
print('\ntarget shape like it  [batch_size,N_class,h,w] =', y.shape)
print('\npredict shape like it [batch_size,N_class,h,w] =', y_hat.shape)

print('\nloss without flatten = ', criterion(y_hat, y.squeeze(1)))
# Move the class dimension last before flattening so each pixel stays aligned with its label
print('\nloss with flatten = ', criterion(y_hat.permute(0, 2, 3, 1).contiguous().view(-1, 4),
                                           y.permute(0, 2, 3, 1).contiguous().view([-1, ])))

Output:

target shape like it  [batch_size,N_class,h,w] = torch.Size([2, 1, 3, 3])

predict shape like it [batch_size,N_class,h,w] = torch.Size([2, 4, 3, 3])

loss without flatten =  tensor(1.6820)

loss with flatten =  tensor(1.6820)