I assume you have already found suitable code snippets for a binary segmentation use case?
If so, you could use it as a baseline and make a few changes for a multi-class segmentation use case:
- use `nn.CrossEntropyLoss` as your criterion
- your model should output a tensor with the shape `[batch_size, nb_classes, height, width]`
- the `target` should be a `LongTensor` with the shape `[batch_size, height, width]` and contain the class indices for each pixel location in the range `[0, nb_classes-1]`
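A minimal sketch of these three points, using random tensors as stand-ins for real data. The 1x1 conv head, `in_channels`, and the sizes are illustrative assumptions, not taken from any particular model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (hypothetical, pick your own)
batch_size, nb_classes, height, width = 2, 4, 8, 8
in_channels = 16  # hypothetical number of feature channels before the head

# Hypothetical segmentation head: a 1x1 conv producing one logit per class
# per pixel. A binary model would typically use out_channels=1 together
# with nn.BCEWithLogitsLoss instead.
head = nn.Conv2d(in_channels, nb_classes, kernel_size=1)

features = torch.randn(batch_size, in_channels, height, width)  # stand-in for backbone output
output = head(features)  # raw logits, no softmax: [batch_size, nb_classes, height, width]

# Target: one class index per pixel, dtype torch.long, values in [0, nb_classes-1]
target = torch.randint(0, nb_classes, (batch_size, height, width), dtype=torch.long)

criterion = nn.CrossEntropyLoss()  # applies log_softmax over dim=1 internally
loss = criterion(output, target)   # scalar, averaged over all pixels by default
loss.backward()

# Predicted mask: the class with the highest logit at each pixel
pred = output.argmax(dim=1)  # [batch_size, height, width]
print(output.shape, target.shape, pred.shape)
# torch.Size([2, 4, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8])
```

Note that the model output should be the raw logits: `nn.CrossEntropyLoss` already applies `log_softmax` internally, so adding a softmax layer in front of it would be wrong.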
Depending on the format of your segmentation mask images, you might need to create a mapping, e.g. between color codes and the corresponding class indices.
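A minimal sketch of such a mapping, assuming the RGB masks are loaded as a `uint8` tensor of shape `[height, width, 3]`. The palette in `COLOR_TO_CLASS` and the helper `rgb_mask_to_class_indices` are hypothetical, so replace them with the colors your dataset actually uses:

```python
import torch

# Hypothetical palette: RGB color -> class index.
# Replace with the colors actually used in your mask images.
COLOR_TO_CLASS = {
    (0, 0, 0): 0,      # e.g. background
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
}

def rgb_mask_to_class_indices(mask: torch.Tensor) -> torch.Tensor:
    """Convert an RGB mask [height, width, 3] into a LongTensor [height, width]
    holding one class index per pixel."""
    # Start with -1 so unmapped pixels can be detected afterwards
    target = torch.full(mask.shape[:2], -1, dtype=torch.long)
    for color, class_idx in COLOR_TO_CLASS.items():
        color_t = torch.tensor(color, dtype=mask.dtype)
        # True where all three channels match this color
        matches = (mask == color_t).all(dim=-1)
        target[matches] = class_idx
    if (target == -1).any():
        raise ValueError("mask contains colors missing from COLOR_TO_CLASS")
    return target

# Usage: a tiny 2x3 mask with one red and one blue pixel
mask = torch.zeros(2, 3, 3, dtype=torch.uint8)
mask[0, 0] = torch.tensor([255, 0, 0], dtype=torch.uint8)
mask[1, 2] = torch.tensor([0, 0, 255], dtype=torch.uint8)
target = rgb_mask_to_class_indices(mask)
print(target.tolist())  # [[1, 0, 0], [0, 0, 3]]
```

Unmapped colors raise an error here instead of silently being assigned a class, which helps catch masks saved with anti-aliasing or lossy compression (e.g. JPEG), where edge pixels may no longer match the palette exactly.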
Let us know if and where you get stuck.