I’ve implemented a custom multi-class 2D Dice loss function to train a U-Net for image segmentation, adapted from code I found online, but when it is used in training the validation loss never changes, and the training loss just oscillates. The U-Net trains fine with Cross-Entropy loss, so I believe the problem is with the custom function, but I cannot find why it behaves this way. Looking into training, I saw that the loss does change during training but rapidly converges, while during validation the loss stays at a constant value (0.1604 in the most recent run) and never moves from it over any number of epochs.
I have also taken other implementations of multi-class dice loss functions, and have faced the same problem. Any clue what could be causing this issue?
def __init__(self, classes, epsilon=1e-5, sigmoid_normalization=True):
    """Multi-class 2D Dice loss.

    :param classes: number of segmentation classes C
    :param epsilon: small constant that guards against division by zero
    :param sigmoid_normalization: if True, normalize logits with a per-channel
        Sigmoid; otherwise with Softmax over the channel dim.
        NOTE(review): for mutually exclusive classes (standard semantic
        segmentation) Softmax is usually the correct choice — independent
        Sigmoids let every channel saturate on its own, which can make the
        loss plateau.  Confirm which regime applies to your labels.
    """
    # NOTE: the original posted code spelled this ``def init`` — Markdown
    # swallowed the dunder underscores; without ``__init__`` the defaults
    # below are never set.
    super().__init__()
    self.epsilon = epsilon
    self.classes = classes
    if sigmoid_normalization:
        self.normalization = nn.Sigmoid()
    else:
        self.normalization = nn.Softmax(dim=1)

def flatten(self, tensor):
    """Flatten a (N, C, H, W) or (C, H, W) tensor to (C, N*H*W).

    The channel axis is moved to the front *before* reshaping so each row
    holds exactly one class.  The original ``tensor.view(self.classes, -1)``
    is only correct for a single un-batched sample; with N > 1 it silently
    interleaves batch elements across class rows, corrupting the per-channel
    Dice and its gradients.
    """
    if tensor.dim() == 3:
        # (C, H, W): treat as a batch of one so the transpose below is uniform
        tensor = tensor.unsqueeze(0)
    # (N, C, H, W) -> (C, N, H, W) -> (C, N*H*W)
    return tensor.transpose(0, 1).contiguous().view(self.classes, -1)

def expand_as_one_hot(self, target):
    """Convert an integer label map to its one-hot encoding.

    :param target: LongTensor of class indices, shape (N, H, W)
    :return: one-hot tensor of shape (N, C, H, W) where C == self.classes
    """
    shape = list(target.size())
    shape.insert(1, self.classes)          # (N, C, H, W)
    src = target.unsqueeze(1)              # (N, 1, H, W) index tensor
    # Scatter 1s along the new class dimension.  The batch dimension is
    # kept (the original ``.squeeze(0)`` only worked for batch size 1).
    return torch.zeros(tuple(shape)).to(target.device).scatter_(1, src, 1)

def compute_per_channel_dice(self, input, target):
    """Per-class Dice coefficient.

    :param input: normalized predictions, shape (N, C, H, W)
    :param target: integer label map, shape (N, H, W)
    :return: 1-D tensor of C Dice scores in [0, 1]
    """
    one_hot_target = self.expand_as_one_hot(target.long())
    # Report the tensor that was actually compared (the original message
    # printed the raw ``target`` size, which never matches ``input``).
    assert input.size() == one_hot_target.size(), \
        "'input' and 'target' must have the same shape: " \
        + str(input.size()) + " and " + str(one_hot_target.size())
    input = self.flatten(input)
    one_hot_target = self.flatten(one_hot_target).float()
    # Per-channel Dice = 2|A∩B| / (|A|+|B|)
    intersect = (input * one_hot_target).sum(-1)
    denominator = (input + one_hot_target).sum(-1)
    # Use the configured epsilon — the original shadowed ``self.epsilon``
    # with a hard-coded local.
    return 2. * intersect / denominator.clamp(min=self.epsilon)

def forward(self, input, target):
    """Dice loss = 1 - mean per-class Dice.

    :param input: raw logits, shape (N, C, H, W)
    :param target: integer label map, shape (N, H, W)
    :return: scalar loss tensor
    """
    prob_input = self.normalization(input)
    per_channel_dice = self.compute_per_channel_dice(prob_input, target)
    # Average the Dice score across all channels/classes
    return torch.mean(1. - per_channel_dice)