class DICELossMultiClass(nn.Module):

```
def __init__(self):
super(DICELossMultiClass, self).__init__()
def forward(self, output, mask):
num_classes = output.size(-1)
dice_eso = 0
for i in range(num_classes):
probs = torch.squeeze(output[:, :, :, :, i], -1) # [batch_size, D, H, W, channel]
target = torch.squeeze(mask[:, :, :, :, i], -1) # one-hot :[batch_size, D, H, W, channel]
num = probs * target
num = torch.sum(num, 3)
num = torch.sum(num, 2)
num = torch.sum(num, 1)
# print( num )
den1 = probs * probs
# print(den1.size())
den1 = torch.sum(den1, 3)
den1 = torch.sum(den1, 2)
den1 = torch.sum(den1, 1)
# print(den1.size())
den2 = target * target
# print(den2.size())
den2 = torch.sum(den2, 3)
den2 = torch.sum(den2, 2)
den2 = torch.sum(den2, 1)
# print(den2.size())
eps = 0.0000001
dice = 2 * ((num + eps) / (den1 + den2 + eps))
# dice_eso = dice[:, 1:]
dice_eso += dice
loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
return loss
```

I used the above function to computed the multi-label dice loss(the target label is encoded by one-hot), but I don’t know whether it is correct because I got some negative value for loss. I think the main reason is that one of channels is for background, which is set to 1 for background pixels,this is different from the binary semantic segmentation.

What can I do for this ?