Not able to plot an image from a Pytorch Dataset

Hi!
I am trying to extract an image and plot it using matplotlib.pyplot.imshow() but I am not able to do so. Here is my Dataset class:

class MelanomaDataset(Dataset):
    def __init__(self, df, meta_features=None, transforms=None):
        self.df = df.reset_index(drop=True)
        self.meta_features = meta_features
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        img_name = self.df.iloc[index].image_name
        img_path = os.path.join('/kaggle/input/siim-isic-melanoma-classification/jpeg/train', img_name + '.jpg')

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            res = self.transforms(image=image)
            image = res['image'].astype(np.float32)
        else:
            image = image.astype(np.float32)

        image = image.transpose(2, 0, 1)

        if self.meta_features is not None:
            img = torch.tensor(image).float()
            meta_feats = torch.tensor(
                self.df.iloc[index][self.meta_features]
            ).float()
        else:
            img = torch.tensor(image).float()
            meta_feats = None
        return {
            'image': img,
            'meta_features': meta_feats,
            'target': torch.tensor(self.df.iloc[index].target).long()
        }

And this is how I am trying to plot a sample from the dataset:

plt.imshow(dataset[1]["image"].permute(1, 2, 0))

It would be great if someone could help out with this problem.
Thanks!

What kind of error are you seeing?

I’m not getting an error. Just the figure that is being printed is white, so the img is not being printed. Here is a screenshot:

I guess you are ignoring a warning explaining you might be clipping the image?

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

I am not getting that warning I just double checked.

In that case I wouldn’t know what’s causing the issue as your code works fine after removing undefined parts and casing the output back to uint8:

class MelanomaDataset(Dataset):
    def __init__(self):
        self.paths = ["image.jpeg"]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img_path = self.paths[index]
        
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).float()

        return {
            'image': image,
        }
    
dataset = MelanomaDataset()
data = dataset[0]
img = data["image"]
plt.imshow(img.permute(1, 2, 0).numpy().astype(np.uint8))

That’s strange. Thanks a lot tho :slight_smile:

Could you double check the dtype and value range of your data["image"] output and check the warning I’ve posted?
Make sure that floating point numbers are in [0, 1] and ints in [0, 255].