I am training a UNet for multi-class segmentation. The classes are 0, 1, and 2, which correspond to background, object, and tip, respectively. In binary segmentation, I could simply use the torchmetrics function dice inside the training loop to calculate the Dice score between each prediction and its target, then average those scores over the epoch to get a per-epoch value. How can this be generalized to a multi-class problem in PyTorch?
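
For concreteness, one possible generalization is to compute a separate Dice score per class from one-hot masks and then average over classes. The following is only a minimal sketch under stated assumptions: the function name `multiclass_dice` is hypothetical, the network is assumed to output logits of shape (N, C, H, W), the target is assumed to be an integer tensor of shape (N, H, W), and the `eps` smoothing term is an arbitrary choice. It is not the torchmetrics implementation.

```python
import torch
import torch.nn.functional as F


def multiclass_dice(logits, target, num_classes=3, eps=1e-7):
    """Return a tensor of per-class Dice scores, shape (num_classes,).

    logits: float tensor (N, C, H, W), assumed raw network output.
    target: int64 tensor (N, H, W) with values in [0, num_classes).
    """
    # Hard prediction: most likely class per pixel, shape (N, H, W).
    pred = logits.argmax(dim=1)

    # One-hot encode both masks and move the class axis to dim 1: (N, C, H, W).
    pred_1h = F.one_hot(pred, num_classes).permute(0, 3, 1, 2).float()
    target_1h = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).float()

    # Sum over batch and spatial dims, keeping one value per class.
    dims = (0, 2, 3)
    intersection = (pred_1h * target_1h).sum(dim=dims)
    cardinality = pred_1h.sum(dim=dims) + target_1h.sum(dim=dims)

    # eps avoids 0/0; a class absent from both masks scores 1.0 here.
    return (2.0 * intersection + eps) / (cardinality + eps)


if __name__ == "__main__":
    # Tiny example with the three classes from the question.
    target = torch.tensor([[[0, 1], [2, 0]]])          # (1, 2, 2)
    logits = torch.zeros(1, 3, 2, 2)
    logits[:, 0] = 1.0                                 # predict background everywhere
    per_class = multiclass_dice(logits, target)
    print(per_class)           # one score each for background, object, tip
    print(per_class[1:].mean())  # mean over foreground classes only
```

Inside a training loop, the per-class tensor could be accumulated per batch and averaged at the end of the epoch, either over all three classes or over the object and tip classes only if the background should not dominate the score.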