Visualising test images after training a model on a segmentation task

I have trained a model for multi-class image segmentation. It returns a torch tensor of shape torch.Size([4, 512, 512]). Now I want to convert this tensor to an image mask so that I can visualise the output my model is predicting, but I am unable to do it.
Following is my mask_to_class function, which converts a mask image to class labels:

def mask_to_class(self, mask):
    # mask: HxWx3 uint8 numpy array of RGB colors
    target = torch.from_numpy(mask)
    h, w = target.shape[0], target.shape[1]
    masks = torch.empty(h, w, dtype=torch.long)
    # find the unique colors and assign each one a class index
    colors = torch.unique(target.view(-1, target.size(2)), dim=0).numpy()
    target = target.permute(2, 0, 1).contiguous()  # HxWx3 -> 3xHxW
    mapping = {tuple(c): t for c, t in zip(colors.tolist(), range(len(colors)))}
    for k in mapping:
        # mark pixels where all three channels match this color
        idx = (target == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
        validx = (idx.sum(0) == 3)
        masks[validx] = torch.tensor(mapping[k], dtype=torch.long)
    return masks
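
For reference, a minimal sketch of what this function produces for a tiny dummy mask (the colors below are made up purely for illustration, and the method is assumed to be callable as a standalone function, i.e. without self):

import numpy as np

# dummy 2x2 RGB mask containing two colors: black and white
mask = np.array([[[0, 0, 0], [255, 255, 255]],
                 [[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)
labels = mask_to_class(mask)
print(labels)
# tensor([[0, 1],
#         [1, 0]])  black -> class 0, white -> class 1 (colors sorted by torch.unique)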
    

Below is my predict_mask function, which I am not able to finish:

def predict_mask(img_path, model):
    img = Image.open(img_path)

    transform = transforms.Compose([transforms.Resize((512, 512)),
                                    transforms.ToTensor()])
    img_transformed = transform(img).unsqueeze_(0)  # add batch dimension
    if train_on_gpu:
        img_transformed = img_transformed.cuda()
        model = model.cuda()
    output = model(img_transformed)
    return output

Right now it is returning a tensor of shape torch.Size([1, 4, 512, 512]). How do I convert it to the required format, which should be an RGB image?

You could get the predicted class indices by using pred = torch.argmax(output, 1).
Then you could apply the reversed mapping to create an RGB image using the same colors as your segmentation masks.
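
For the shapes in this thread, a minimal sketch of what this yields (shape values taken from the question above):

# output: [1, 4, 512, 512], i.e. one score per class for every pixel
pred = torch.argmax(output, 1)  # [1, 512, 512], each entry in {0, 1, 2, 3}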


I did this, and now I am pretty much lost as to what to do next.

def predict_fundus_image(img_path, model):
    img = Image.open(img_path)

    transform = transforms.Compose([transforms.Resize((512, 512)),
                                    transforms.ToTensor()])
    img_transformed = transform(img).unsqueeze_(0)  # add batch dimension
    if train_on_gpu:
        img_transformed = img_transformed.cuda()
        model = model.cuda()
    output = model(img_transformed)
    output = torch.argmax(output, 1)  # class index per pixel
    return output

img = predict_fundus_image(img_path, model)
mapping = {(0, 0, 0): 0, (0, 0, 255): 1, (255, 0, 0): 2, (255, 255, 255): 3}
rev_map = {v: k for k, v in mapping.items()}

Assuming you are using the dummy code from here, you could do the following:

rev_mapping = {mapping[k]: k for k in mapping}  # class index -> RGB color
pred = mask # or e.g. pred = torch.randint(0, 19, (224, 224))
pred_image = torch.zeros(3, pred.size(0), pred.size(1), dtype=torch.uint8)
for k in rev_mapping:
    print(k)  # debug output: the class index currently being colored
    pred_image[:, pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1)

plt.imshow(pred_image.permute(1, 2, 0).numpy())
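
(The indexing pred_image[:, pred == k] selects all three channels at every pixel where pred equals k, and .view(3, 1) reshapes the color so it broadcasts across those positions.)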

Okay, thanks a lot, but where is the outcome of the trained model used here?

pred would be calculated as torch.argmax(output, 1).
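
In other words, a minimal sketch tying the two snippets together (shapes taken from your earlier post):

output = model(img_transformed)            # [1, 4, 512, 512] logits
pred = torch.argmax(output, 1).squeeze(0)  # [512, 512] class indices
# pred is then fed into the rev_mapping loop above to build pred_image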


Thanks a lot. I did the following but am obtaining only a black image, so I guess my model sucks, and accuracy is really not the right metric for judging unbalanced datasets.

def predict_fundus_image(img_path, model):
    img = Image.open(img_path)

    transform = transforms.Compose([transforms.Resize((512, 512)),
                                    transforms.ToTensor()])
    img_transformed = transform(img).unsqueeze_(0)  # add batch dimension
    if train_on_gpu:
        img_transformed = img_transformed.cuda()
        model = model.cuda()
    output = model(img_transformed)
    output = torch.argmax(output, 1)
    return output

img = predict_fundus_image(img_path, model)
img = img.squeeze(0).cpu()  # drop the batch dim; move to CPU so it can index the CPU tensor below
mapping = {(0, 0, 0): 0, (0, 0, 255): 1, (255, 0, 0): 2, (255, 255, 255): 3}
rev_mapping = {mapping[k]: k for k in mapping}
pred = img
pred_image = torch.zeros(3, img.size(0), img.size(1), dtype=torch.uint8)
for k in rev_mapping:
    print(k)
    pred_image[:, pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1)
plt.imshow(pred_image.permute(1, 2, 0).numpy())

You could check the number of uniquely predicted classes with print(pred.unique(return_counts=True)).
This would probably show a high count for the background class and very few pixels for the other classes.
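
For example (the counts below are made up purely for illustration):

classes, counts = pred.unique(return_counts=True)
print(classes, counts)
# e.g. tensor([0, 1]) tensor([262000, 144])  almost everything is background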


out_num = output.argmax(2)[-1].item()

Can you please explain this line to me?

This line of code calculates the argmax (the index of the max value) along dimension 2 of output (the dimension used in argmax is removed), indexes the result with [-1] to take the last element along dim 0, and converts this value to a Python int.

I think a better way to see what exactly is returned would be to execute the operations sequentially and print the intermediate return values.
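
For example, assuming output is 3-dimensional:

step1 = output.argmax(2)  # argmax over dim 2; that dim is removed
print(step1.shape)
step2 = step1[-1]         # indexes dim 0, taking its last element
print(step2.shape)
out_num = step2.item()    # works only if step2 contains a single element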

@ptrblck Does the output need to be passed through a softmax before pred = torch.argmax(output, 1)? I was doing segmentation with 4 classes as well; my model doesn’t have a softmax at the final layer and I use cross-entropy loss.

No, you don’t need to apply the softmax activation to get the predictions, since the logits already give you the predicted class. Softmax is monotonically increasing, so it would only normalize the values to probabilities without changing their order, and you would get the same result:

import torch
import torch.nn.functional as F

x = torch.randn(10, 10)
print(torch.argmax(x, dim=1))
print(torch.argmax(F.softmax(x, dim=1), dim=1))

Finally, I got my segmentation model working perfectly fine by following all your guidance in this forum. Thank you so much!! @ptrblck, you’re a hero XDD
