Using cross entropy loss with semantic segmentation model

If my model gives outputs in the shape of [N, C, H, W], where N is the batch size and C is the number of channels (one per output class), and I have corresponding masks in the shape of [N, H, W], am I okay to just plug these as-is into a CrossEntropyLoss function? Or do I need to do some sort of arg-max operation on the model output to get it into the exact same shape as the mask? The documentation for the function says it supports 2D pixel-wise loss, so I think this is fine, but I'm trying to be 100% sure (trying to debug why my segmentation model performance is atrocious).


You can just use CrossEntropyLoss in this case. It treats dim 1 (C) as the class dimension and applies the softmax over the C class scores at each pixel, so you pass the raw [N, C, H, W] logits together with the [N, H, W] mask of class indices (dtype long) as-is — no argmax needed.
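
A minimal sketch with random tensors standing in for a real model output and mask (the sizes are just placeholders):

```python
import torch
import torch.nn as nn

N, C, H, W = 4, 3, 16, 16  # batch size, number of classes, height, width

logits = torch.randn(N, C, H, W)                            # raw model output, no softmax/argmax
target = torch.randint(0, C, (N, H, W), dtype=torch.long)   # mask of class indices per pixel

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)   # log-softmax over dim 1 is applied internally
print(loss.item())
```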


Great, thanks so much!