How to select a loss function for 3D segmentation networks?

I need to train a net on 3D images with dimensions Batch*Channel*Depth*Height*Width, and the output and label have dimensions B*D*H*W, but I can't find a suitable loss function among the torch.nn loss functions.
Can you give me some suggestions? Thank you!

Hi, depending on the problem, we've had good luck with CrossEntropyLoss as well as Dice.
Assuming your model produces a two-channel output of raw scores (logits; CrossEntropyLoss applies log-softmax internally, so it should not be fed probabilities), you can flatten the output to a 2D view with one row per voxel and the mask to a matching 1D view:

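```python
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
output = model(input)                                 # (B, 2, D, H, W) logits
output = output.permute(0, 2, 3, 4, 1).contiguous()   # channels last: (B, D, H, W, 2)
output = output.view(-1, 2)                           # one row per voxel: (B*D*H*W, 2)
mask = mask.view(-1).long()                           # class index per voxel: (B*D*H*W,)
loss = criterion(output, mask)
```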
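As a side note, recent PyTorch releases let CrossEntropyLoss take the 5D output directly, shaped (B, C, D, H, W), together with a (B, D, H, W) integer target, so the permute/view step is optional. The small check below compares both forms; the random tensors, shapes and seed are arbitrary stand-ins for the real model output and mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, D, H, W = 2, 2, 4, 5, 6
logits = torch.randn(B, C, D, H, W)        # stand-in for model(input)
mask = torch.randint(0, C, (B, D, H, W))   # integer class label per voxel

criterion = nn.CrossEntropyLoss()

# Flattened form, as in the snippet above: (B*D*H*W, C) vs (B*D*H*W,)
flat = logits.permute(0, 2, 3, 4, 1).reshape(-1, C)
loss_flat = criterion(flat, mask.reshape(-1))

# Direct form: CrossEntropyLoss accepts (B, C, d1, d2, ...) inputs
# with (B, d1, d2, ...) integer targets
loss_direct = criterion(logits, mask)

print(loss_flat.item(), loss_direct.item())
```

With the default `reduction="mean"`, both forms average over every voxel, so they should give the same value up to floating-point error.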

For the Dice index, here’s a 2D implementation. I can’t send you the 3D version right now as I’m on the go:
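Independently of the implementation mentioned above, here is a minimal soft Dice loss sketch for 3D volumes. It assumes (B, C, D, H, W) logits and a (B, D, H, W) integer mask. The helper name `soft_dice_loss_3d` and the `eps` smoothing term are hypothetical choices, not a torch.nn API.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss_3d(logits, mask, eps=1e-6):
    """Soft Dice loss for 3D multi-class segmentation (hypothetical helper).

    logits: (B, C, D, H, W) raw network scores
    mask:   (B, D, H, W) integer class labels in [0, C)
    """
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)  # per-voxel class probabilities

    # One-hot encode the mask: (B, D, H, W, C) -> (B, C, D, H, W)
    one_hot = F.one_hot(mask.long(), num_classes)
    one_hot = one_hot.permute(0, 4, 1, 2, 3).to(probs.dtype)

    # Sum over batch and spatial dims, keeping one Dice score per class
    dims = (0, 2, 3, 4)
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)

    # Average over classes; 0 means perfect overlap
    return 1 - dice.mean()


# Usage with random data
torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 8, 8, requires_grad=True)
mask = torch.randint(0, 3, (2, 4, 8, 8))
loss = soft_dice_loss_3d(logits, mask)
loss.backward()
print(loss.item())
```

This version pools the whole batch per class and includes the background class in the average. Computing Dice per sample, or dropping the background channel, are common alternatives, and which works better is problem-dependent.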


@lantiga: Can you send the 3D Dice implementation now?