Dice loss unchanging after the first epoch

I’ve implemented a custom multi-class 2D Dice loss function (adapted from code I found online) to train a U-Net for image segmentation, but when I use it for training the test loss never changes and the training loss just oscillates. The U-Net trains fine with cross-entropy, so I believe the problem is in the custom loss function, but I can’t figure out why it behaves this way. Digging into the training, I saw that the loss does change at first but converges very rapidly, and during validation it stays at a constant value (0.1604 in the most recent run) and never moves, no matter how many epochs I train.

I have also tried other implementations of multi-class Dice loss and hit the same problem. Any idea what could be causing this?
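For reference, what the code below is meant to compute is the soft Dice score per channel, with the denominator clamped at $\epsilon$ to avoid division by zero, averaged over the $C$ classes:

$$\mathrm{Dice}_c = \frac{2\sum_i p_{c,i}\,t_{c,i}}{\max\!\left(\sum_i \left(p_{c,i} + t_{c,i}\right),\ \epsilon\right)}, \qquad \mathcal{L} = 1 - \frac{1}{C}\sum_{c=1}^{C} \mathrm{Dice}_c$$

where $p_{c,i}$ are the normalized network outputs and $t_{c,i}$ the one-hot encoded labels.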

import torch
import torch.nn as nn


class DiceLoss2D(nn.Module):
    def __init__(self, classes, epsilon=1e-5, sigmoid_normalization=True):
        super(DiceLoss2D, self).__init__()
        self.epsilon = epsilon
        self.classes = classes

        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)  # TODO test ?

    def flatten(self, tensor):
        # Collapse all dimensions except the class dimension: C x (everything else)
        return tensor.view(self.classes, -1)

    def expand_as_one_hot(self, target):
        """
        Converts a label image to CxHxW, where each label is converted to
        its corresponding one-hot vector.
        :param target: tensor of shape (1xHxW)
        :return: 3D output tensor (CxHxW), where C is the number of classes
        """
        # target = target.squeeze()
        shape = list(target.size())
        shape.insert(1, self.classes)
        shape = tuple(shape)
        # expand the target tensor to Nx1xHxW so scatter_ can fill dim 1
        src = target.unsqueeze(1)
        return torch.zeros(shape).to(target.device).scatter_(1, src, 1).squeeze(0)

    def compute_per_channel_dice(self, input, target):
        one_hot_target = self.expand_as_one_hot(target.long())
        # one_hot_target: CxHxW after the squeeze in expand_as_one_hot

        input = input.squeeze()
        assert input.size() == one_hot_target.size(), \
            "'input' and 'target' must have the same shape: " \
            + str(input.size()) + " and " + str(one_hot_target.size())

        input = self.flatten(input)
        one_hot_target = self.flatten(one_hot_target).float()

        # Compute the per-channel Dice coefficient
        intersect = (input * one_hot_target).sum(-1)
        denominator = (input + one_hot_target).sum(-1)
        return 2. * intersect / denominator.clamp(min=self.epsilon)

    def forward(self, input, target):
        # input: BxCxWxH (raw network outputs)
        # target: BxWxH (integer class labels)
        prob_input = self.normalization(input)
        per_channel_dice = self.compute_per_channel_dice(prob_input, target)
        # Average the Dice score across all channels/classes
        return torch.mean(1. - per_channel_dice)
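In case it helps with reproducing, here is a minimal, self-contained way to exercise the loss. The random tensors stand in for my actual data and network, and num_classes = 4 and the 64x64 size are arbitrary placeholders (batch size 1, which the squeeze calls above seem to assume):

import torch

num_classes = 4  # placeholder, not my real class count
criterion = DiceLoss2D(classes=num_classes)

# Fake batch: raw logits of shape BxCxHxW and integer labels of shape BxHxW
logits = torch.randn(1, num_classes, 64, 64, requires_grad=True)
labels = torch.randint(0, num_classes, (1, 64, 64))

loss = criterion(logits, labels)
loss.backward()
print(loss.item())  # a value in [0, 1]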