I am trying to learn how to train a segmentation network in PyTorch.

I have run into the problem of choosing a suitable loss function and, as a result, of building a correct training loop.

The masks of the images are `{0,1}`-valued tensors with 4 channels, where each channel represents a different label. Which loss function should I use for this? Should it be applied separately to each channel (class) and then averaged?
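Two common options, sketched under stated assumptions. If every pixel belongs to exactly one of the 4 classes (one-hot masks), the mask can be turned into class indices with `argmax` and passed to `nn.CrossEntropyLoss` together with the raw logits; no softmax is applied beforehand, because the loss applies log-softmax internally. If labels can overlap (multi-label masks), `nn.BCEWithLogitsLoss` treats each channel as an independent binary problem and averages over all elements. The helper names `multiclass_loss` and `multilabel_loss` are hypothetical.

```python
import torch
import torch.nn as nn


def multiclass_loss(logits, one_hot_mask):
    """Hypothetical helper: assumes exactly one active channel per pixel.

    logits:       N x C x H x W raw model outputs (no softmax applied)
    one_hot_mask: N x C x H x W tensor with values in {0, 1}
    """
    target = one_hot_mask.argmax(dim=1)  # N x H x W class indices
    return nn.CrossEntropyLoss()(logits, target)


def multilabel_loss(logits, mask):
    """Hypothetical helper: channels are independent binary masks."""
    # sigmoid + binary cross-entropy per element, averaged over N, C, H, W
    return nn.BCEWithLogitsLoss()(logits, mask.float())


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 8, 8)
    classes = torch.randint(0, 4, (2, 8, 8))
    # one_hot gives N x H x W x C; permute to N x C x H x W
    one_hot = nn.functional.one_hot(classes, num_classes=4).permute(0, 3, 1, 2)
    print(multiclass_loss(logits, one_hot).item())
    print(multilabel_loss(logits, one_hot).item())
```

Which one fits depends on whether the 4 labels are mutually exclusive; a Dice term can also be added on top of either one.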

I have this training loop for a UNet model:

```
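import gc

import torch
from tqdm import tqdm

# model, trainloader, criterion, optimizer, scheduler, total and acc_steps are defined elsewhere
i = 0  # global step counter for gradient accumulation
for epoch in range(total):
    pbar = tqdm(trainloader)
    model.train()
    running_loss = 0.
    for (x, mask) in pbar:
        # softmax over the channel dim assumes exactly one label per pixel
        y_pred = torch.softmax(model(x.cuda()), dim=1)
        loss = criterion(y_pred, mask.cuda())
        running_loss += loss.item() / len(trainloader)  # mean loss over the epoch
        pbar.set_postfix({'Epoch': f'{epoch+1}/{total}',
                          'current_loss': f'{loss.item():.2f}'})
        # scale so the accumulated gradient averages over acc_steps batches
        (loss / acc_steps).backward()
        if (i + 1) % acc_steps == 0:  # accumulation of gradient
            optimizer.step()
            optimizer.zero_grad()
        gc.collect()
        torch.cuda.empty_cache()
        i += 1
    # print(running_loss)
    scheduler.step(running_loss)  # e.g. ReduceLROnPlateau, which expects a metric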
```
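On gradient accumulation: each `loss.backward()` call adds to `.grad`, so without scaling, `acc_steps` accumulated micro-batches produce a gradient `acc_steps` times larger than a single batch of the combined size would. The toy check below uses a hypothetical `nn.Linear` model on random data (not the UNet) and shows that dividing each micro-batch loss by `acc_steps` reproduces the full-batch gradient, assuming the loss is a mean and the micro-batches are equally sized.

```python
import torch
import torch.nn as nn


def full_batch_grad(model, x, y):
    model.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    return model.weight.grad.clone()


def accumulated_grad(model, x, y, acc_steps):
    model.zero_grad()
    for xb, yb in zip(x.chunk(acc_steps), y.chunk(acc_steps)):
        loss = nn.functional.mse_loss(model(xb), yb)
        (loss / acc_steps).backward()  # .grad accumulates across calls
    return model.weight.grad.clone()


torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
g_full = full_batch_grad(model, x, y)
g_acc = accumulated_grad(model, x, y, acc_steps=4)
print(torch.allclose(g_full, g_acc, atol=1e-6))  # True
```

With BatchNorm layers (common in UNets) the equivalence is only approximate, because batch statistics are computed per micro-batch.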

The loss I am currently using is the Dice coefficient loss, computed per channel and then averaged:

```
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, eps=1e-15):
        super().__init__()
        self.eps = eps  # smoothing term; avoids 0/0 when a class is absent

    def forward(self, y_pred, y_true):
        '''
        y_true: mask, torch.Tensor of shape NxCxHxW
        y_pred: prediction, torch.Tensor of shape NxCxHxW
        Here C is number of classes (4)
        '''
        num_classes = y_true.size(1)
        dice_coef = 0.
        for cls in range(num_classes):
            intersection = y_true[:, cls].float() * y_pred[:, cls].float()
            # |A| + |B|: the Dice denominator (not the IoU union)
            union = y_true[:, cls].float() + y_pred[:, cls].float()
            # the sums run over the whole batch and all pixels of this class
            dice_coef += (2 * intersection.sum() + self.eps) / (union.sum() + self.eps)
        dice_coef /= num_classes
        return 1. - dice_coef  # maximize dice coef by minimizing this value
```
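The per-class loop can be replaced by sums over the batch and spatial dimensions, which yields one soft Dice score per class, 2·Σ(p·t) / (Σp + Σt), in a single pass. A sketch, assuming the same NxCxHxW layout; `soft_dice_loss` is a hypothetical name, and like the class above it pools all images in the batch before computing each class's score.

```python
import torch


def soft_dice_loss(y_pred, y_true, eps=1e-15):
    """Mean over classes of (1 - Dice); inputs shaped N x C x H x W."""
    y_pred = y_pred.float()
    y_true = y_true.float()
    dims = (0, 2, 3)  # sum over batch and pixels, keep the class axis
    intersection = (y_pred * y_true).sum(dims)         # shape: C
    denominator = y_pred.sum(dims) + y_true.sum(dims)  # shape: C
    dice_per_class = (2 * intersection + eps) / (denominator + eps)
    return 1. - dice_per_class.mean()
```

Summing over `(2, 3)` only, then averaging over both the batch and class axes, would instead give a per-image Dice, which weights small and large objects more evenly across images.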

Is this a correct approach for a Dice-coefficient-based loss function, and is the training loop correct overall?

Thanks for any advice.