Hi!
I am trying to extract an image from my dataset and plot it with `matplotlib.pyplot.imshow()`,
but I am not able to get it to display. Here is my Dataset class:
class MelanomaDataset(Dataset):
    """Dataset of melanoma JPEG images with optional tabular meta features.

    Each item is a dict with:
      * ``'image'``: float32 tensor in channels-first (C, H, W) layout, RGB order.
        Without transforms the values stay in the 0-255 range of the decoded
        JPEG. To display one with ``plt.imshow``, convert it to uint8 (or divide
        by 255), because imshow clips float RGB data to [0, 1].
      * ``'meta_features'``: float32 tensor of the selected columns, or ``None``.
      * ``'target'``: int64 tensor built from the row's ``target`` column.

    Args:
        df: DataFrame with ``image_name`` and ``target`` columns, plus the
            columns named in ``meta_features`` if given.
        meta_features: Column name(s) to return as meta features, or ``None``.
        transforms: Optional callable invoked as ``transforms(image=...)`` that
            returns a dict with an ``'image'`` entry (albumentations-style
            interface; assumed, confirm against caller). The returned image must
            be a NumPy array, since ``.astype`` is called on it.
        image_dir: Directory that holds ``<image_name>.jpg`` files.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image cannot be read.
    """

    def __init__(self, df, meta_features=None, transforms=None,
                 image_dir='/kaggle/input/siim-isic-melanoma-classification/jpeg/train'):
        # Reset to a 0..N-1 index so positional lookups in __getitem__ match rows.
        self.df = df.reset_index(drop=True)
        self.meta_features = meta_features
        self.transforms = transforms
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load, optionally transform, and tensorize the sample at ``index``."""
        row = self.df.iloc[index]
        img_path = os.path.join(self.image_dir, row.image_name + '.jpg')

        # cv2.imread does not raise on a missing or unreadable file; it returns
        # None, which would otherwise surface later as an opaque cvtColor error.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV decodes to BGR; reorder channels to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        # HWC -> CHW (channels-first). Values are NOT rescaled to [0, 1].
        image = image.astype(np.float32).transpose(2, 0, 1)
        img = torch.tensor(image)

        if self.meta_features is not None:
            # A single row of a mixed-dtype DataFrame is an object-dtype Series,
            # which torch.tensor cannot convert, so cast to float32 explicitly.
            meta_feats = torch.tensor(
                np.asarray(row[self.meta_features], dtype=np.float32)
            )
        else:
            # NOTE(review): torch's default collate_fn may reject None values.
            # Confirm this works if the dataset is batched with a DataLoader.
            meta_feats = None

        return {
            'image': img,
            'meta_features': meta_feats,
            'target': torch.tensor(row.target).long(),
        }
And this is how I am trying to plot a sample from the dataset:
# The dataset returns float32 CHW tensors with values in [0, 255]. imshow clips
# float RGB data to [0, 1], which renders almost everything white, so convert
# to HWC uint8 first. NOTE: if your transforms normalize the image (e.g. a
# mean/std Normalize), undo that normalization before this step.
sample_image = dataset[1]["image"].permute(1, 2, 0)
plt.imshow(sample_image.clamp(0, 255).byte().numpy())
It would be great if someone could help out with this problem.
Thanks!