Does anyone know how I can implement a Tversky loss for a multiclass segmentation problem (I have 4 classes)? I found an implementation on GitHub (Multi class, Issue #3, nabsabraham/focal-tversky-unet) and ported the code to PyTorch, but I did not understand the permutation. When I try it in my model I get the following error: “RuntimeError: number of dims don’t match in permute”.
def __init__(self, alpha_t=0.5, beta_t=0.5,
self.alpha_t = alpha_t
self.beta_t = beta_t
def forward(self, label, pred, smooth=1):
# comment out if your model contains a sigmoid or equivalent activation layer
#label = torch.sigmoid(label)
# flatten label and prediction tensors
label = label.permute(3,1,2,0)
pred = pred.permute(3,1,2,0)
flat_label = label.flatten()
flat_pred = pred.flatten()
TP = torch.sum(flat_label * flat_pred, 1)
FN = torch.sum(flat_label * (1-flat_pred), 1)
FP = torch.sum((1-flat_label)*flat_pred, 1)
Tversky = (TP + smooth) / (TP + self.alpha_t * FP + self.beta_t * FN + smooth)
return 1 - Tversky
Could someone help me?
The permutations assume 4-dimensional tensors.
Here comes the first difference to Keras/TF: in PyTorch these will be Batch, Channel/Class, Height, Width (BCHW), with the channel dimension indexing the class (in TF it’s BHWC, as pointed out in the comment you linked).
So what you want is for TP, FN and FP to sum over B, H and W. I think you could do that with
torch.sum(label * pred, dim=(0, 2, 3)), so you would get a vector with just the class dimension, i.e. a score per class.
Note that this convention differs from the usual averaging over batches, which may or may not affect what you want to do when changing batch sizes (depends on the optimizer etc.).
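To make the summation over B, H and W concrete, here is a minimal sketch, not a tested drop-in for the original repository. The function name tversky_per_class, the toy shapes, the softmax on the logits, and the final mean over classes are all assumptions for illustration; label is assumed to be already one-hot in BCHW layout.

```python
import torch


def tversky_per_class(label, pred, alpha_t=0.5, beta_t=0.5, smooth=1.0):
    """Tversky index per class; label and pred are both (B, C, H, W)."""
    dims = (0, 2, 3)  # sum over batch, height, width; keep the class dim
    tp = torch.sum(label * pred, dim=dims)
    fn = torch.sum(label * (1 - pred), dim=dims)
    fp = torch.sum((1 - label) * pred, dim=dims)
    return (tp + smooth) / (tp + alpha_t * fp + beta_t * fn + smooth)


# Toy data (assumed shapes): batch of 2, 4 classes, 8x8 images.
torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)
pred = torch.softmax(logits, dim=1)  # per-class probabilities
classes = torch.randint(0, 4, (2, 8, 8))
# one_hot puts the class dim last, so move it to dim 1 to get BCHW.
label = torch.nn.functional.one_hot(classes, 4).permute(0, 3, 1, 2).float()

scores = tversky_per_class(label, pred)  # one score per class, shape (4,)
loss = (1 - scores).mean()  # averaging over classes is one possible reduction
```

No permute is needed here because both tensors are already in PyTorch's BCHW layout; the permute in the Keras version exists only to move the class dimension to the front of a BHWC tensor.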
Thank you, @tom. The problem I am dealing with now is that my target is not in BCHW shape, but in BHW. My prediction, on the other hand, is in BCHW shape. So what should I do about it? Which transformation should I apply to the tensors before the TP, FN, and FP calculation?
So your target is class numbers?
target_onehot = (torch.arange(num_classes, device=label.device)[None, :, None, None] == label[:, None, :, :]).to(pred.dtype) could be something like the label in the one-hot-encoded format that the code snippet seemed to use. (But do try it on a test example, I am just typing it into the answer here without testing it enough.)
Of course, it might be neat to look into avoiding this expansion, but you could always revisit that when you find it becomes a bottleneck.
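As a quick check of that expansion, the sketch below builds the one-hot target from a BHW tensor of class indices in two ways and compares them. The toy shapes and variable names are assumptions; the second variant uses torch.nn.functional.one_hot, which puts the class dimension last and therefore needs a permute to reach BCHW.

```python
import torch
import torch.nn.functional as F

num_classes = 4
torch.manual_seed(0)
label = torch.randint(0, num_classes, (2, 8, 8))  # (B, H, W), integer class ids
pred = torch.softmax(torch.randn(2, num_classes, 8, 8), dim=1)  # (B, C, H, W)

# Variant 1: broadcast a (1, C, 1, 1) range of class ids against (B, 1, H, W).
class_ids = torch.arange(num_classes, device=label.device)
onehot_a = (class_ids[None, :, None, None] == label[:, None, :, :]).to(pred.dtype)

# Variant 2: F.one_hot yields (B, H, W, C), so move the class dim to position 1.
onehot_b = F.one_hot(label, num_classes).permute(0, 3, 1, 2).to(pred.dtype)
```

Either result has the same shape as pred, so it can be passed directly into the TP, FN and FP sums.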
Okay. Thank you @tom. I will try that for sure.
Hey, is it possible to apply a weight to each class in the Tversky loss? I was doing this with cross-entropy, but I did not find anything about it for Tversky.
If you have a vector with the weights, you can multiply it with the per-class scores before summing them.
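A minimal sketch of that weighting, under a few assumptions: the function name weighted_tversky_loss and the example weights are hypothetical, label is one-hot in BCHW layout, and dividing by the sum of the weights is my own choice to keep the loss in the same range regardless of the weight scale.

```python
import torch


def weighted_tversky_loss(label, pred, class_weights,
                          alpha_t=0.5, beta_t=0.5, smooth=1.0):
    """Class-weighted Tversky loss; label and pred are (B, C, H, W)."""
    dims = (0, 2, 3)  # reduce over batch, height, width
    tp = torch.sum(label * pred, dim=dims)
    fn = torch.sum(label * (1 - pred), dim=dims)
    fp = torch.sum((1 - label) * pred, dim=dims)
    score = (tp + smooth) / (tp + alpha_t * fp + beta_t * fn + smooth)
    per_class_loss = 1 - score  # shape (C,)
    # Weight each class, then sum; normalizing by the weight sum is an assumption.
    return torch.sum(class_weights * per_class_loss) / torch.sum(class_weights)


# Toy data (assumed shapes and weights).
torch.manual_seed(0)
pred = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
classes = torch.randint(0, 4, (2, 8, 8))
label = torch.nn.functional.one_hot(classes, 4).permute(0, 3, 1, 2).float()
weights = torch.tensor([0.1, 0.3, 0.3, 0.3])  # hypothetical per-class weights

loss = weighted_tversky_loss(label, pred, weights)
```

The weights tensor needs to live on the same device as pred and label when training on a GPU.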