Tversky loss for multiclass segmentation

Hey,

Does anyone know how I can implement a Tversky loss for a multiclass segmentation problem (I have 4 classes)? I found this on GitHub: Multi class · Issue #3 · nabsabraham/focal-tversky-unet · GitHub. I ported the code to PyTorch, but I did not understand the permutation. When I try it in my model, I get the following error: “RuntimeError: number of dims don’t match in permute”.
My code:

```python
import torch
import torch.nn as nn


class TverskyLoss(nn.Module):
    def __init__(self, alpha_t=0.5, beta_t=0.5):
        super(TverskyLoss, self).__init__()
        self.alpha_t = alpha_t
        self.beta_t = beta_t

    def forward(self, label, pred, smooth=1):
        # uncomment if your model does NOT already end with a sigmoid (or equivalent) activation
        # pred = torch.sigmoid(pred)

        # flatten label and prediction tensors
        label = label.permute(3, 1, 2, 0)
        pred = pred.permute(3, 1, 2, 0)

        flat_label = label.flatten()
        flat_pred = pred.flatten()
        TP = torch.sum(flat_label * flat_pred, 1)
        FN = torch.sum(flat_label * (1 - flat_pred), 1)
        FP = torch.sum((1 - flat_label) * flat_pred, 1)

        Tversky = (TP + smooth) / (TP + self.alpha_t * FP + self.beta_t * FN + smooth)

        return 1 - Tversky
```

Could someone help me?

The permutations assume 4-dimensional tensors.
Here comes the first difference from Keras/TF: in PyTorch, these are Batch, Channel/Class, Height, Width (BCHW), with the channel dimension indexing the class (in TF it's BHWC, as pointed out in the comment you linked).
So what you want is for TP, FN, and FP to sum over B, H, and W (you can do that with `torch.sum(label * pred, dim=(0, 2, 3))`), so that you get a vector with just the class dimension, i.e. one score per class.
Note that this convention differs from the usual averaging over the batch: the counts are pooled over the whole batch before the ratio is taken. This may or may not affect what you want to do when you change the batch size (it depends on the optimizer etc.).
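A minimal sketch of the per-class version described above, assuming `pred` already holds class probabilities (e.g. `logits.softmax(dim=1)`) and `label` is one-hot encoded, both of shape (B, C, H, W). It keeps the question's weighting (alpha on FP, beta on FN) but uses PyTorch's usual (input, target) argument order; reducing the per-class losses with `.mean()` is an assumption, not something stated in the thread.

```python
import torch
import torch.nn as nn


class TverskyLoss(nn.Module):
    """Per-class Tversky loss for BCHW tensors (sketch)."""

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1.0):
        super().__init__()
        self.alpha = alpha  # weight of false positives
        self.beta = beta    # weight of false negatives
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        # pred:  (B, C, H, W) class probabilities, e.g. softmax over dim=1
        # label: (B, C, H, W) one-hot target with the same dtype as pred
        dims = (0, 2, 3)  # sum over batch, height and width; keep the class axis
        tp = torch.sum(label * pred, dim=dims)
        fp = torch.sum((1 - label) * pred, dim=dims)
        fn = torch.sum(label * (1 - pred), dim=dims)
        # one Tversky index per class, shape (C,)
        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        # average the per-class losses into a scalar (assumed reduction)
        return (1 - tversky).mean()
```

With `alpha = beta = 0.5` each per-class term reduces to a (smoothed) Dice score; a larger `beta` penalizes false negatives more, which is the usual reason to pick Tversky over Dice.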

Best regards

Thomas

Thank you, @tom. The problem I am dealing with now is that my target is not in BCHW shape but in BHW. My prediction, on the other hand, is in BCHW shape. What should I do about it? Which transformation should I apply to the tensors before the TP, FN, and FP calculation?

Best,
Giulia

So your target contains class indices?
Then `target_onehot = (torch.arange(num_classes, device=label.device)[None, :, None, None] == label[:, None, :, :]).to(pred.dtype)` should give you the label in the one-hot-encoded BCHW format that the code snippet seemed to use. (But do try it on a test example; I am just typing it into the answer here without testing it thoroughly.)
Of course, it might be neat to avoid this expansion, but you can always revisit that if it becomes a bottleneck.
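A quick way to test the conversion, as suggested: the sketch below compares the broadcasting trick against the built-in `torch.nn.functional.one_hot`, which appends the class axis last and therefore needs a `permute(0, 3, 1, 2)` to reach BCHW. The helper names `onehot_broadcast` and `onehot_builtin` are made up for illustration; `device=label.device` is there so the comparison also works when the label lives on the GPU.

```python
import torch
import torch.nn.functional as F


def onehot_broadcast(label: torch.Tensor, num_classes: int, dtype: torch.dtype) -> torch.Tensor:
    # label: (B, H, W) integer class indices -> (B, C, H, W)
    classes = torch.arange(num_classes, device=label.device)[None, :, None, None]  # (1, C, 1, 1)
    return (classes == label[:, None, :, :]).to(dtype)  # broadcasts to (B, C, H, W)


def onehot_builtin(label: torch.Tensor, num_classes: int, dtype: torch.dtype) -> torch.Tensor:
    # F.one_hot returns (B, H, W, C); move the class axis to position 1
    return F.one_hot(label.long(), num_classes).permute(0, 3, 1, 2).to(dtype)


label = torch.tensor([[[0, 1], [2, 3]]])  # (B=1, H=2, W=2)
a = onehot_broadcast(label, 4, torch.float32)
b = onehot_builtin(label, 4, torch.float32)
print(a.shape)            # torch.Size([1, 4, 2, 2])
print(torch.equal(a, b))  # True
```

Both produce the same tensor, so either can feed the TP/FP/FN sums over `dim=(0, 2, 3)`.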

Best regards

Thomas

Okay. Thank you @tom. I will try that for sure.

Hey, is it possible to apply a weight to each class in the Tversky loss? I was doing that with the cross-entropy loss, but I did not find anything about it for Tversky.

Best Regards,
Giulia

If you have a vector with the weights, you can multiply it with the per-class scores before summing them.
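A sketch of that weighting, assuming `pred` holds (B, C, H, W) class probabilities and `label` is a one-hot target of the same shape. `class_weights` is a hypothetical 1-D tensor with one entry per class, and dividing by its sum is an extra choice (not from the thread) that keeps the loss in [0, 1]. It weights the per-class losses `1 - T_c` rather than the scores `T_c`; the two differ only by a constant, so the gradients are the same.

```python
import torch


def weighted_tversky_loss(pred, label, class_weights, alpha=0.5, beta=0.5, smooth=1.0):
    """Class-weighted Tversky loss (sketch).

    pred:          (B, C, H, W) class probabilities
    label:         (B, C, H, W) one-hot target
    class_weights: (C,) hypothetical per-class weights
    """
    dims = (0, 2, 3)  # sum over batch, height and width
    tp = torch.sum(label * pred, dim=dims)
    fp = torch.sum((1 - label) * pred, dim=dims)
    fn = torch.sum(label * (1 - pred), dim=dims)
    tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)  # (C,)
    per_class_loss = 1 - tversky
    # multiply each class's loss by its weight, then sum; normalizing by
    # the weight sum keeps the result in [0, 1]
    w = class_weights.to(device=pred.device, dtype=pred.dtype)
    return (w * per_class_loss).sum() / w.sum()
```

With all weights equal this reduces to the plain mean over classes; raising the weight of a rare class makes its overlap errors count more, similar to the `weight` argument of the cross-entropy loss.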

Best regards

Thomas