I think you are spot on.
I am using nn.CrossEntropyLoss as my criterion, so my label format is (N, H, W) (see Multiclass Segmentation - #2 by ptrblck).
Does this mean I should go back to the following mapping?
```python
# target is an RGB mask of shape (H, W, 3); h, w = target.size(0), target.size(1)
colors = torch.unique(target.view(-1, target.size(2)), dim=0)
mapping = {tuple(c.tolist()): t for t, c in enumerate(colors)}

target = target.permute(2, 0, 1).contiguous()  # (3, H, W)
print(f"target shape: {target.shape}\n")

mask = torch.zeros(h, w, dtype=torch.long)
for k in mapping:
    # Get all indices for the current class color
    idx = (target == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3)  # check that all three channels match
    mask[validx] = mapping[k]
```
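As a sanity check, here is a minimal, self-contained run of that mapping on a dummy 2x2 RGB target (the two colors and the image size are just assumptions for illustration):

```python
import torch

# Dummy 2x2 RGB target with two colors: black and red
target = torch.tensor([[[0, 0, 0], [255, 0, 0]],
                       [[255, 0, 0], [0, 0, 0]]], dtype=torch.uint8)
h, w = target.size(0), target.size(1)

# torch.unique with dim=0 returns the unique color rows in sorted order
colors = torch.unique(target.view(-1, target.size(2)), dim=0)
mapping = {tuple(c.tolist()): t for t, c in enumerate(colors)}

target = target.permute(2, 0, 1).contiguous()  # (3, H, W)
mask = torch.zeros(h, w, dtype=torch.long)
for k in mapping:
    idx = (target == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
    validx = (idx.sum(0) == 3)  # all three channels match this color
    mask[validx] = mapping[k]

print(mask)  # tensor([[0, 1], [1, 0]]): black -> class 0, red -> class 1
```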
So that way we have the class index in each pixel? Then when it’s time to output the prediction at the end of the model, we just convert back to RGB like how you demonstrated in my other post?