I’ve implemented a custom multi-class 2D Dice loss (adapted from code I found online) to train a Unet for image segmentation, but during training the test loss never changes and the training loss just oscillates. The Unet trains fine with Cross-Entropy, so I believe the problem is in the custom loss function, but I cannot find why it behaves this way. Looking into training more closely, I saw that the loss does change at first but rapidly converges; during validation the loss stays at a constant value (0.1604 in the most recent run) and never moves from it over any number of epochs.
I have also tried other implementations of multi-class Dice loss functions and faced the same problem. Any clue what could be causing this issue?
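In case it matters, this is a simplified sketch of how the loss is wired into my training/validation loop; model, train_loader, val_loader, num_epochs, and the optimizer settings are placeholders rather than my exact code. The loss class itself follows right after.

    import torch

    criterion = DiceLoss2D(classes=4)  # placeholder class count
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        for images, masks in train_loader:    # images: BxCinxHxW, masks: BxHxW
            optimizer.zero_grad()
            logits = model(images)            # BxCxHxW raw logits
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()

        # validation computes the loss the same way, just without gradients
        model.eval()
        with torch.no_grad():
            val_loss = sum(criterion(model(x), y).item()
                           for x, y in val_loader) / len(val_loader)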
import torch
import torch.nn as nn

class DiceLoss2D(nn.Module):
    def __init__(self, classes, epsilon=1e-5, sigmoid_normalization=True):
        super(DiceLoss2D, self).__init__()
        self.epsilon = epsilon
        self.classes = classes
        # Map the network's raw logits into [0, 1] before computing Dice
        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)  # TODO test ?

    def flatten(self, tensor):
        # Collapse every dimension except the class dimension: -> (C, everything else)
        return tensor.view(self.classes, -1)

    def expand_as_one_hot(self, target):
        """
        Converts a label image to CxHxW, where each label gets converted to
        its corresponding one-hot vector.
        :param target: tensor of shape (1xHxW) holding integer class labels
        :return: 3D output tensor (CxHxW) where C is the number of classes
        """
        # target = target.squeeze()
        shape = list(target.size())
        shape.insert(1, self.classes)
        shape = tuple(shape)
        # expand the target tensor to Nx1xHxW, then scatter 1s along the class dim
        src = target.unsqueeze(1)
        return torch.zeros(shape).to(target.device).scatter_(1, src, 1).squeeze(0)

    def compute_per_channel_dice(self, input, target):
        one_hot_target = self.expand_as_one_hot(target.long())
        # target is now one-hot encoded: CxHxW
        input = input.squeeze()
        assert input.size() == one_hot_target.size(), \
            "'input' and 'target' must have the same shape, got " \
            + str(input.size()) + " and " + str(one_hot_target.size())
        input = self.flatten(input)
        one_hot_target = self.flatten(one_hot_target).float()
        # Per-channel Dice coefficient: 2 * |A ∩ B| / (|A| + |B|)
        intersect = (input * one_hot_target).sum(-1)
        denominator = (input + one_hot_target).sum(-1)
        return 2. * intersect / denominator.clamp(min=self.epsilon)

    def forward(self, input, target):
        # input: BxCxWxH raw logits, target: BxWxH integer class labels
        prob_input = self.normalization(input)
        per_channel_dice = self.compute_per_channel_dice(prob_input, target)
        # Average the Dice score over all channels/classes and turn it into a loss
        return torch.mean(1. - per_channel_dice)
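For reference, a standalone shape check like the one below runs end to end for me; the batch size of 1, the 4 classes, and the 64x64 resolution are arbitrary values for illustration, not my actual data:

    import torch

    criterion = DiceLoss2D(classes=4)
    logits = torch.randn(1, 4, 64, 64)         # BxCxHxW raw network output
    target = torch.randint(0, 4, (1, 64, 64))  # BxHxW integer class labels
    loss = criterion(logits, target)
    print(loss.item())                          # a single scalar in [0, 1]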