# Training loop for segmentation

I am trying to learn how to train a network for segmentation in PyTorch.
I ran into the problem of selecting a proper loss function and thus constructing a proper training loop.
The image masks are `{0,1}`-valued matrices in which each channel represents a different label (4 channels total). Which loss function should I use for this? Should it be applied separately to every channel (class) and then averaged?
I have this training loop for a UNet model:

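```python
for epoch in range(total):
    model.train()
    running_loss = 0.
    # NOTE: the batch-loop header and the target argument were cut off in the
    # post. `pbar` (a progress bar over the data loader), `y` (the mask) and
    # the counter `i` (initialised before the loop) are assumed here.
    for x, y in pbar:
        y_pred = torch.softmax(model(x.cuda()), dim=1)
        loss = criterion(y_pred, y.cuda())
        running_loss += loss.item()  # assumed; not shown in the post
        pbar.set_postfix({'Epoch': f'{epoch+1}/{total}',
                          'current_loss': f'{loss.item():.2f}'})
        loss.backward()
        if (i+1) % acc_steps == 0:  # accumulation of gradient
            optimizer.step()
            optimizer.zero_grad()  # assumed; reset the accumulated gradients
        gc.collect()
        torch.cuda.empty_cache()
        i += 1
    #print(running_loss)
    scheduler.step(running_loss)
```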

The loss I am using now is the Dice coefficient loss, computed per channel and then averaged:

```python
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, eps=1e-15):
        super().__init__()
        self.eps = eps

    def forward(self, y_pred, y_true):
        '''
        y_true: mask, torch.Tensor of shape NxCxHxW
        y_pred: prediction, torch.Tensor of shape NxCxHxW

        Here C is number of classes (4)
        '''
        num_classes = y_true.size(1)
        dice_coef = 0.
        for cls in range(num_classes):
            intersection = y_true[:, cls].float() * y_pred[:, cls].float()
            union = y_true[:, cls].float() + y_pred[:, cls].float()

            # per-class Dice coefficient, computed over the whole batch
            dice_coef += (2 * intersection.sum() + self.eps) / (union.sum() + self.eps)
        dice_coef /= num_classes  # average over the classes
        return 1. - dice_coef  # maximize dice coef by minimizing this value
```

Is this a correct approach for the Dice coefficient-based loss function and the training loop overall?

As an alternative to the Dice loss, you could also use e.g. `nn.CrossEntropyLoss` if each pixel belongs to exactly one class (note that you would need to pass the raw logits to this loss function rather than the softmax output).
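To illustrate the point about logits, here is a minimal sketch on random tensors (the shapes and the seed are made up for the example, not taken from your data). `nn.CrossEntropyLoss` applies `log_softmax` internally, so feeding it the softmax output would normalize the predictions twice and give a different loss value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)         # N x C x H x W, raw model output
target = torch.randint(0, 4, (2, 8, 8))  # N x H x W, one class index per pixel

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)

# The same value computed by hand: log_softmax over the channel dim, then NLL.
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Passing probabilities instead of logits applies softmax twice.
loss_double_softmax = criterion(torch.softmax(logits, dim=1), target)

print(loss.item(), manual.item(), loss_double_softmax.item())
```

In your loop that would mean dropping the `torch.softmax` call before the criterion when you switch to this loss.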

The training loop looks alright. I’m wondering why you are calling `empty_cache()` inside the loop, as this could make your code slower and shouldn’t save any memory.
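If you want to check this on your own machine, something along these lines could help. It is only a sketch: the helper name `cuda_memory_snapshot` is made up for illustration, while `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved` and `torch.cuda.empty_cache` are the actual PyTorch calls. `empty_cache()` only releases cached blocks that no tensor is using, so the memory held by live tensors should not change.

```python
import torch


def cuda_memory_snapshot():
    """Return allocated/reserved CUDA memory in bytes, or None without a GPU."""
    if not torch.cuda.is_available():
        return None
    return {
        "allocated": torch.cuda.memory_allocated(),  # held by live tensors
        "reserved": torch.cuda.memory_reserved(),    # held by the caching allocator
    }


if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")
    print("tensor alive:     ", cuda_memory_snapshot())
    del x
    print("after del:        ", cuda_memory_snapshot())
    torch.cuda.empty_cache()  # hands unused cached blocks back to the driver
    print("after empty_cache:", cuda_memory_snapshot())
```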


Would crossentropy loss work if the target mask has shape `Nx4xHxW`? I thought that the mask for cross entropy needs to be `NxHxW` only.

I will try it.
I use `empty_cache()` because I thought it actually frees up some GPU memory; the images are very large (and the batch size is very small, hence the gradient accumulation).

If the target is currently a one-hot encoded tensor, you should pass `torch.argmax(target, 1)` as the target to `nn.CrossEntropyLoss`.
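A small sketch of that conversion, using a synthetic one-hot mask (the sizes are arbitrary and only for illustration). It assumes that exactly one of the 4 channels is 1 for every pixel; if some pixels have no channel set, you would need to decide how to handle them first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 4, 8, 8

# Build a fake one-hot mask of shape N x C x H x W from random class indices.
labels = torch.randint(0, C, (N, H, W))
one_hot = F.one_hot(labels, num_classes=C).permute(0, 3, 1, 2).float()

# Collapse the channel dimension: N x C x H x W -> N x H x W (dtype int64).
target = torch.argmax(one_hot, dim=1)

logits = torch.randn(N, C, H, W)  # stand-in for the raw model output
loss = nn.CrossEntropyLoss()(logits, target)
print(target.shape, target.dtype, loss.item())
```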
