DataLoader not reading in masks correctly?

I think you are spot on.

I am using nn.CrossEntropyLoss as my criterion, so my target format is (N, H, W), with one class index (dtype torch.long) per pixel (see Multiclass Segmentation - #2 by ptrblck).
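For reference, here is a minimal sketch of the shapes nn.CrossEntropyLoss expects for segmentation. The sizes N, C, H, W are made-up values for illustration: the model outputs raw logits of shape (N, C, H, W), and the target holds class indices of shape (N, H, W), with no channel dimension and no one-hot encoding.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: batch N, number of classes C, spatial size H x W
N, C, H, W = 2, 4, 8, 8

criterion = nn.CrossEntropyLoss()

# Model output: raw, unnormalized logits of shape (N, C, H, W)
logits = torch.randn(N, C, H, W)

# Target: one class index per pixel, shape (N, H, W), dtype torch.long,
# with values in [0, C - 1]
target = torch.randint(0, C, (N, H, W), dtype=torch.long)

loss = criterion(logits, target)
print(loss.shape)  # torch.Size([]): a scalar with the default reduction="mean"
```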

So does this mean I should go back to the following mapping?

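        # target: RGB mask as a uint8 tensor of shape (H, W, 3)
        h, w = target.shape[:2]
        mask = torch.zeros(h, w, dtype=torch.long)

        # Every unique RGB color in this mask becomes one class index.
        # Note: the mapping is rebuilt per image, so the indices only agree
        # across samples if every mask contains exactly the same set of colors.
        colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
        target = target.permute(2, 0, 1).contiguous()  # (H, W, 3) -> (3, H, W)
        mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
        print(f"target shape: {target.shape}\n")

        for k in mapping:
            # Get all indices for current class: (3, H, W) == (3, 1, 1) compares per channel
            idx = (target == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3)  # Check that all channels match
            mask[validx] = torch.tensor(mapping[k], dtype=torch.long)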
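Because the mapping above is derived from whichever colors happen to appear in each mask, one alternative is a fixed color-to-class palette shared by the whole dataset. The sketch below assumes the masks are (H, W, 3) uint8 tensors; `PALETTE` and `rgb_to_class_index` are hypothetical names, and the colors are placeholders to replace with the ones your dataset actually uses.

```python
import torch

# Hypothetical fixed palette: RGB color -> class index.
# Replace with the actual colors used in your dataset's masks.
PALETTE = {
    (0, 0, 0): 0,
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
}

def rgb_to_class_index(target: torch.Tensor, palette: dict) -> torch.Tensor:
    """Convert an RGB mask (H, W, 3), uint8, into a class-index mask (H, W), long.

    Pixels whose color is not in the palette are left at -1, so they can be
    spotted, or skipped via nn.CrossEntropyLoss(ignore_index=-1).
    """
    h, w, _ = target.shape
    mask = torch.full((h, w), -1, dtype=torch.long)
    for color, cls in palette.items():
        color_t = torch.tensor(color, dtype=target.dtype)  # shape (3,)
        # Broadcast (H, W, 3) == (3,) and require all three channels to match
        match = (target == color_t).all(dim=-1)
        mask[match] = cls
    return mask

# Usage: a tiny 2x2 mask with black, red and blue pixels
rgb = torch.zeros(2, 2, 3, dtype=torch.uint8)
rgb[0, 1] = torch.tensor([255, 0, 0], dtype=torch.uint8)
rgb[1, 0] = torch.tensor([0, 0, 255], dtype=torch.uint8)
out = rgb_to_class_index(rgb, PALETTE)  # rows: [0, 1] and [3, 0]
```

With a fixed palette, class 1 always means the same color in every sample, regardless of which other colors a particular mask contains.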

So this way each pixel holds its class index? And when it's time to output the prediction at the end of the model, we just convert it back to RGB, like you demonstrated in my other post?
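Going back to RGB could look roughly like the sketch below. It assumes the model outputs logits of shape (N, C, H, W) and that a hypothetical `palette` tensor lists the RGB color of class i in row i, in the same order as the color-to-class mapping used for training. Taking `argmax` over the class dimension gives the predicted index per pixel, and indexing the palette with that index tensor looks up one color per pixel.

```python
import torch

# Hypothetical palette indexed by class: row i is the RGB color of class i.
# It must use the same ordering as the color -> class mapping used in training.
palette = torch.tensor([
    [0, 0, 0],
    [255, 0, 0],
    [0, 255, 0],
    [0, 0, 255],
], dtype=torch.uint8)

def logits_to_rgb(logits: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    """Map model output (N, C, H, W) to RGB images (N, H, W, 3)."""
    pred = logits.argmax(dim=1)  # (N, H, W): most likely class per pixel
    return palette[pred]         # advanced indexing: one palette row per pixel

# Usage: class 0 wins everywhere except pixel (0, 0), where class 2 wins
logits = torch.zeros(1, 4, 2, 2)
logits[0, 0] = 1.0
logits[0, 2, 0, 0] = 5.0
rgb = logits_to_rgb(logits, palette)  # shape (1, 2, 2, 3)
```

If an image library expects channels first, `rgb.permute(0, 3, 1, 2)` gives (N, 3, H, W).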