Visualise the test images after training the model on a segmentation task

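Assuming you are using the dummy code from here (where mapping maps each RGB color tuple to a class index), you could invert the mapping and paint every pixel with the color of its predicted class: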

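import torch
import matplotlib.pyplot as plt

# mapping (from the dummy code): {(R, G, B): class_index}; invert it to class_index -> color
rev_mapping = {mapping[k]: k for k in mapping}
# pred must be an (H, W) tensor of class indices, e.g. pred = torch.randint(0, 19, (224, 224)).
# For model logits of shape [N, C, H, W] (output = your model's output): pred = output.argmax(dim=1)[0]
pred = mask
pred_image = torch.zeros(3, pred.size(0), pred.size(1), dtype=torch.uint8)
for k in rev_mapping:
    # paint every pixel predicted as class k with that class's RGB color (broadcast over pixels)
    pred_image[:, pred == k] = torch.tensor(rev_mapping[k]).byte().view(3, 1)

plt.imshow(pred_image.permute(1, 2, 0).numpy())  # CHW -> HWC, as matplotlib expects
plt.show()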
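Since mapping and mask come from the dummy code, here is a self-contained sketch you can run on its own. It swaps the per-class loop for a lookup-table variant: build a (num_classes, 3) palette once, then index it with the prediction tensor. The three-class mapping and the helper names build_palette and colorize are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical mapping in the same form as the dummy code's `mapping`:
# RGB color tuple -> class index (an assumption for illustration only).
mapping = {
    (0, 0, 0): 0,    # background
    (255, 0, 0): 1,  # class 1
    (0, 255, 0): 2,  # class 2
}


def build_palette(mapping, num_classes=None):
    """Return a (num_classes, 3) uint8 lookup table; row k holds the color of class k."""
    if num_classes is None:
        num_classes = max(mapping.values()) + 1
    palette = torch.zeros(num_classes, 3, dtype=torch.uint8)  # unmapped classes stay black
    for color, cls in mapping.items():
        palette[cls] = torch.tensor(color, dtype=torch.uint8)
    return palette


def colorize(pred, palette):
    """Map an (H, W) tensor of class indices to an (H, W, 3) uint8 RGB image."""
    # Indexing with a long tensor gathers one palette row per pixel.
    return palette[pred.long()]


pred = torch.tensor([[0, 1],
                     [2, 1]])
rgb = colorize(pred, build_palette(mapping))
print(rgb.shape)  # torch.Size([2, 2, 3])
```

The result is already in (H, W, 3) layout, so it can go straight to plt.imshow(rgb.numpy()) without the permute. Indices outside the palette raise an error, so pass num_classes if the model can predict classes that mapping does not cover.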