I am running multi-label segmentation of 3D data (batch x classes x H x W x D). The target is one-hot encoded (all 0s and 1s).
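For reference, a target with this layout can be built from an integer label volume with `torch.nn.functional.one_hot`. That function appends the class axis last, so it has to be moved to position 1. This is a minimal sketch that assumes a label map of shape batch x H x W x D with values in [0, num_classes). Note that it produces exactly one active class per voxel. A truly multi-label (overlapping, multi-hot) target would need to be assembled differently.

```python
import torch
import torch.nn.functional as F

def to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert an integer label map (B, H, W, D) to a one-hot tensor (B, C, H, W, D)."""
    one_hot = F.one_hot(labels.long(), num_classes)  # (B, H, W, D, C)
    return one_hot.permute(0, 4, 1, 2, 3).float()    # move the class axis to dim 1

labels = torch.randint(0, 9, (2, 8, 8, 8))
target = to_one_hot(labels, num_classes=9)
print(target.shape)  # torch.Size([2, 9, 8, 8, 8])
```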
I have some broad questions about which loss to use.
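In the V-Net paper (https://arxiv.org/pdf/1606.04797.pdf), the Dice coefficient is defined as:

$$D = \frac{2\sum_{i}^{N} p_i g_i}{\sum_{i}^{N} p_i^2 + \sum_{i}^{N} g_i^2}$$

where $p_i$ are the predicted probabilities and $g_i$ the binary ground-truth values over the $N$ voxels. It ranges from 0 to 1 and is meant to be maximized, like the regular IoU/Dice.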
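A minimal sketch of that formula, assuming (B, C, H, W, D) tensors of probabilities and binary targets, computes one Dice value per sample and per class. Each entry stays in [0, 1] because $2\sum p_i g_i \le \sum p_i^2 + \sum g_i^2$. The function name `vnet_dice` is hypothetical, chosen for illustration.

```python
import torch

def vnet_dice(probs: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Per-sample, per-class V-Net Dice for (B, C, H, W, D) tensors; returns (B, C)."""
    dims = (2, 3, 4)  # sum over the spatial axes only
    num = 2 * (probs * target).sum(dim=dims)
    den = (probs * probs).sum(dim=dims) + (target * target).sum(dim=dims) + eps
    return num / den

torch.manual_seed(0)
target = torch.randint(0, 2, (2, 3, 4, 4, 4)).float()
probs = torch.rand(2, 3, 4, 4, 4)
d = vnet_dice(probs, target)
print(d.shape)                    # torch.Size([2, 3])
print(vnet_dice(target, target))  # values very close to 1 for a perfect prediction
```

Averaging the (B, C) matrix, rather than summing it, keeps a single reported number in the same [0, 1] range.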
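In the Generalized Dice paper (https://arxiv.org/pdf/1707.03237.pdf), the GDL is:

$$\mathrm{GDL} = 1 - 2\,\frac{\sum_{l} w_l \sum_{n} r_{ln} p_{ln}}{\sum_{l} w_l \sum_{n} \left(r_{ln} + p_{ln}\right)}, \qquad w_l = \frac{1}{\left(\sum_{n} r_{ln}\right)^2}$$

where $r_{ln}$ is the reference (ground truth) and $p_{ln}$ the prediction for label $l$ at voxel $n$.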
and the authors say this about the weight:
when choosing the GDLv weighting, the contribution of each label is
corrected by the inverse of its volume, thus reducing the well-known correlation
between region size and Dice score
which I understand very well.
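The inverse-volume weighting can be illustrated with a toy reference of two labels over 100 voxels: a large region of 90 voxels and a small one of 10. Under $w_l = 1/(\sum_n r_{ln})^2$ the small label gets $1/10^2 = 0.01$ and the large one $1/90^2 \approx 1.2 \times 10^{-4}$, so the small structure is no longer drowned out. `gdl_weights` is a hypothetical helper written for this sketch.

```python
import torch

def gdl_weights(target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """w_l = 1 / (sum_n r_ln)^2 for a flattened (C, N) reference; returns (C,)."""
    volume = target.sum(dim=-1)
    return 1.0 / (volume * volume).clamp(min=eps)  # clamp avoids division by zero for empty labels

# Two labels over 100 voxels: a large region (90 voxels) and a small one (10 voxels).
target = torch.zeros(2, 100)
target[0, :90] = 1
target[1, 90:] = 1
w = gdl_weights(target)
print(w)
```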
So I implemented both losses. For the V-Net Dice I used the following code from rogertrullo/pytorch on GitHub (`pytorch/functional.py`, branch `rogertrullo-dice_loss`):
```python
import torch
import torch.nn.functional as F

torch.manual_seed(1001)
out = torch.randn(3, 9, 64, 64, 64)  # raw network outputs (logits)
# printed: tensor(5.2134) tensor(-5.4812)
seg = torch.randint(0, 2, (3, 9, 64, 64, 64))  # binary target (0s and 1s), same shape as out

def dice_loss(prediction, target, epsilon=1e-6):
    """
    prediction: tensor of size Batch x nclasses x H x W x D holding raw scores for each class
    target: 1-hot representation of the ground truth, should have the same size as prediction
    """
    assert prediction.size() == target.size(), "prediction and target sizes must be equal."
    assert prediction.dim() == 5, "prediction must be a 5D tensor."
    # torch.unique also works on CUDA tensors, unlike target.numpy()
    assert set(torch.unique(target).tolist()) <= {0, 1}, "target must only contain zeros and ones"

    probs = F.softmax(prediction, dim=1)  # channel/class-wise probabilities

    num = (probs * target).sum(dim=(2, 3, 4))     # b,c -- sum of p*g
    den1 = (probs * probs).sum(dim=(2, 3, 4))     # b,c -- sum of p^2
    den2 = (target * target).sum(dim=(2, 3, 4))   # b,c -- sum of g^2
    den = den1 + den2 + epsilon

    dice = 2 * (num / den)  # b,c -- Dice per sample and per class
    dice_eso = dice  # dice[:, 1:] would ignore the background Dice and keep the foreground
    dice_total = torch.sum(dice_eso) / dice_eso.size(0)  # divide by batch size
    return dice_total

dice_loss(prediction=out, target=seg)
```
With it I get a Dice score of 1.9096. Why is it greater than 1, and when using it as a loss criterion during training, what is the upper limit it should be maximized toward?
And for GDL I used the following code from wolny/pytorch-3dunet on GitHub (`pytorch-3dunet/losses.py`, commit eafaa5f830eebfb6dbc4e570d1a4c6b6e25f2a1e):
```python
class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, prediction, target, weight):
        assert prediction.size() == target.size(), "'prediction' and 'target' must have the same shape"
        prediction = flatten(prediction)  # flatten all dimensions except channel/class
        target = flatten(target)
        target = target.float()

        if prediction.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            prediction = torch.cat((prediction, 1 - prediction), dim=0)
            target = torch.cat((target, 1 - target), dim=0)

        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        w_l.requires_grad = False

        intersect = (prediction * target).sum(-1)
        print(intersect.shape)
        intersect = intersect * w_l

        denominator = (prediction + target).sum(-1)
        print(denominator)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 1 - (2 * (intersect.sum() / denominator.sum()))


GeneralizedDiceLoss(normalization='softmax').dice(prediction=out, target=seg, weight=None)
```
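The class above depends on `_AbstractDiceLoss` and `flatten` from pytorch-3dunet, so a self-contained sketch of the same GDL formula may be easier to experiment with. It assumes (B, C, H, W, D) logits, applies softmax over the class axis itself, and pools batch and spatial axes into one voxel axis per class, similar to the flattening above. The name `generalized_dice_loss` is hypothetical. Because $p \cdot r \le (p + r)/2$ for values in [0, 1], the returned value stays in [0, 1] and is meant to be minimized.

```python
import torch
import torch.nn.functional as F

def generalized_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """1 - GDL coefficient for (B, C, H, W, D) logits and binary targets of the same shape."""
    probs = F.softmax(logits, dim=1)
    c = probs.size(1)
    p = probs.transpose(0, 1).reshape(c, -1)           # (C, N): batch and spatial axes pooled
    r = target.transpose(0, 1).reshape(c, -1).float()  # (C, N)
    w = 1.0 / (r.sum(-1) ** 2).clamp(min=eps)          # w_l = 1 / (sum_n r_ln)^2; target has no grad
    intersect = (w * (p * r).sum(-1)).sum()
    denominator = (w * (p + r).sum(-1)).sum().clamp(min=eps)
    return 1.0 - 2.0 * intersect / denominator

torch.manual_seed(1001)
out = torch.randn(3, 9, 8, 8, 8)
seg = torch.randint(0, 2, (3, 9, 8, 8, 8))
loss = generalized_dice_loss(out, seg)
```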
I get a dice score with the same
I don't know how GDL behaves compared to the V-Net Dice, but can the two be compared, at least to some extent, in terms of optimization?
Also, how do I use this in training?
```python
img, seg = img.to(device), seg.to(device)
out = model(img)
optimizer.zero_grad()
loss = dice_loss_3d(out, seg)  # out and seg are already on `device`, so no extra .cuda() is needed
# print(loss)
loss.backward()
optimizer.step()
running_loss += loss.item()
```
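One common convention, assumed here rather than taken from either paper's code, is to turn the Dice score into a quantity to minimize, for example 1 minus the mean per-class V-Net Dice, which is 0 for a perfect overlap. The post does not show `dice_loss_3d`, so this sketch uses its own hypothetical `soft_dice_loss`, plus a toy single-layer model and synthetic data as placeholders for a real 3D network and dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_dice_loss(logits, target, eps=1e-6):
    """1 - mean V-Net Dice over batch and classes; lies in [0, 1]."""
    probs = F.softmax(logits, dim=1)
    dims = (2, 3, 4)
    num = 2 * (probs * target).sum(dim=dims)
    den = (probs * probs).sum(dim=dims) + (target * target).sum(dim=dims) + eps
    return 1 - (num / den).mean()

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 3
# Toy stand-in for a 3D segmentation network (hypothetical, for illustration only).
model = nn.Conv3d(1, num_classes, kernel_size=3, padding=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

img = torch.randn(2, 1, 8, 8, 8)
labels = (img[:, 0] > 0).long() + (img[:, 0] > 1).long()  # synthetic labels in {0, 1, 2}
seg = F.one_hot(labels, num_classes).permute(0, 4, 1, 2, 3).float()
img, seg = img.to(device), seg.to(device)

running_loss = 0.0
losses = []
for step in range(50):
    out = model(img)
    optimizer.zero_grad()
    loss = soft_dice_loss(out, seg)  # minimized; 0 means perfect overlap
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    losses.append(loss.item())
```

Swapping `soft_dice_loss` for a GDL-style loss that also returns 1 minus the coefficient should leave the rest of the loop unchanged.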