Training loop for segmentation

I am trying to learn how to train a network for segmentation in PyTorch.
I ran into the problem of selecting a proper loss function and, with it, constructing a proper training loop.
The masks of the images are {0,1}-valued matrices, and each channel represents a different label (4 channels total). Which loss function should I use for this? Should it be applied separately to every channel (class) and then averaged?
I have this training loop for a UNet model:

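for epoch in range(total):
    running_loss = 0.
    pbar = tqdm(trainloader)
    for i, (x, mask) in enumerate(pbar):
        y_pred = torch.softmax(model(x.cuda()), dim=1)
        loss = criterion(y_pred, mask.cuda())
        loss.backward()
        running_loss += loss.item() / len(trainloader)
        pbar.set_postfix({'Epoch': f'{epoch+1}/{total}', 'Loss': f'{running_loss:.4f}'})
        if (i+1) % acc_steps == 0:  # accumulation of gradient
            optimizer.step()
            optimizer.zero_grad()
            torch.cuda.empty_cache()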

The loss I am using now is the Dice coefficient loss, computed per channel and then averaged:

class DiceLoss(nn.Module):
    def __init__(self, eps=1e-15):
        super().__init__()  # required so nn.Module is set up correctly
        self.eps = eps

    def forward(self, y_pred, y_true):
        """
        y_true: mask, torch.Tensor of shape NxCxHxW
        y_pred: prediction, torch.Tensor of shape NxCxHxW
        Here C is number of classes (4)
        """
        num_classes = y_true.size(1)
        dice_coef = 0.
        for cls in range(num_classes):
            intersection = y_true[:, cls].float() * y_pred[:, cls].float()
            union = y_true[:, cls].float() + y_pred[:, cls].float()
            dice_coef += (2 * intersection.sum() + self.eps) / (union.sum() + self.eps)
        # average over classes so the loss stays in [0, 1]
        return 1. - dice_coef / num_classes  # maximize dice coef by minimizing this value

Is this a correct approach for the Dice coefficient-based loss function and the training loop overall?

Thanks for any advice.

As an alternative to the Dice loss, you could also use e.g. nn.CrossEntropyLoss, if each pixel belongs to exactly one class. Note that you would need to pass the raw logits to this loss function, not the softmax output.
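
To illustrate the point about logits, here is a minimal sketch on random tensors. The shapes and the names `logits` and `target` are placeholders, not taken from the original code. It compares nn.CrossEntropyLoss on raw logits with an explicit log_softmax followed by nll_loss, and shows that feeding softmax probabilities gives a different value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)         # stand-in for model(x): NxCxHxW
target = torch.randint(0, 4, (2, 8, 8))  # class index per pixel: NxHxW

criterion = nn.CrossEntropyLoss()
loss_from_logits = criterion(logits, target)

# Equivalent manual computation: log_softmax followed by negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Passing probabilities means softmax is effectively applied twice,
# which yields a different (and flatter) loss value.
loss_from_probs = criterion(torch.softmax(logits, dim=1), target)
```

So if you switch to this criterion, the torch.softmax call in the training loop should be dropped before computing the loss.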

The training loop looks alright. I'm wondering why you are calling empty_cache() inside the loop, as this could make your code slower and shouldn't save any memory.
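
For reference, a gradient accumulation loop without empty_cache() could look roughly like the sketch below. The one-layer model, the random batches and the SGD optimizer are hypothetical stand-ins so that the snippet runs on CPU. Dividing the loss by `acc_steps` is one common choice, so that the accumulated gradient is an average rather than a sum.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real UNet and data loader.
model = nn.Conv2d(3, 4, kernel_size=1)  # 4 output channels = 4 classes
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
batches = [(torch.randn(1, 3, 8, 8), torch.randint(0, 4, (1, 8, 8)))
           for _ in range(4)]

acc_steps = 2
num_updates = 0
optimizer.zero_grad()
for i, (x, target) in enumerate(batches):
    logits = model(x)                             # raw logits, no softmax
    loss = criterion(logits, target) / acc_steps  # scale so grads average
    loss.backward()                               # grads add up across iterations
    if (i + 1) % acc_steps == 0:
        optimizer.step()                          # one update per acc_steps batches
        optimizer.zero_grad()
        num_updates += 1
```

As far as I know, the caching allocator reuses freed blocks on its own, so empty_cache() mainly hands cached memory back to the driver for other processes rather than giving PyTorch itself more room.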


Would cross-entropy loss work if the target mask has shape Nx4xHxW? I thought that the mask for cross entropy needs to be NxHxW only.

I will try it.
I use empty_cache() because I thought that it frees up some GPU memory. The images are very large and the batch size is very small, hence the gradient accumulation.

If the target is currently a one-hot encoded tensor, you should pass torch.argmax(target, 1) as the target to nn.CrossEntropyLoss.
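
A small sketch of that conversion, assuming the mask is strictly one-hot along dim 1. The mask here is built from random labels purely for illustration, and `one_hot_mask` is a placeholder name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 4, 8, 8

# Hypothetical one-hot mask of shape NxCxHxW built from random class indices.
labels = torch.randint(0, C, (N, H, W))
one_hot_mask = F.one_hot(labels, num_classes=C).permute(0, 3, 1, 2)

# Collapse the channel dimension into class indices: NxCxHxW -> NxHxW.
target = torch.argmax(one_hot_mask, dim=1)

logits = torch.randn(N, C, H, W)  # stand-in for the raw model output
loss = nn.CrossEntropyLoss()(logits, target)
```

Note that this only makes sense if every pixel has exactly one active channel. A pixel with all channels set to zero would silently end up as class 0, so a separate background class might be needed in that case.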
