I have a `.npy` dataset file, and loading it works fine when `transform=None`.
But when I pass `transform=transform_train`, the following error occurs:
`TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>`
How can I fix it?
import numpy as np
import os
class CIFAR10:
    """Dataset wrapper around a CIFAR-10 split stored as a pickled ``.npy`` dict.

    The file ``<root>/cifar10/<split>.npy`` must contain a dict with
    ``"images"`` and ``"labels"`` entries that are indexed in parallel.
    """

    def __init__(self, root='/database', split="l_train", transform=None):
        """Load the requested split into memory.

        Args:
            root: Directory that contains the ``cifar10`` folder.
            split: File stem of the split to load (``<split>.npy``).
            transform: Optional callable applied to each image in
                ``__getitem__``; with ``None`` images are returned as stored.
        """
        path = os.path.join(root, "cifar10", split + ".npy")
        # SECURITY: allow_pickle=True unpickles arbitrary Python objects, so
        # only load split files that come from a trusted source.
        # .item() unwraps the 0-d object array that holds the saved dict.
        self.dataset = np.load(path, allow_pickle=True).item()
        self.transform = transform

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is passed to ``self.transform`` exactly as stored (a
        ``numpy.ndarray`` for this data), so a transform that only accepts
        PIL images has to convert it first.
        """
        image = self.dataset["images"][idx]
        label = self.dataset["labels"][idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.dataset["images"])
# Training-time augmentation pipeline. The dataset yields numpy arrays, while
# the flip/crop transforms below raise
# "TypeError: img should be PIL Image" when given an ndarray, so the array is
# converted to a PIL image first.
transform_train = transforms.Compose([
    # Assumes each image is an H x W x C uint8 array — TODO confirm; a
    # C x H x W or float layout must be transposed/rescaled before this step.
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    # NOTE(review): if the images are already 32x32, a 32x32 crop without
    # padding leaves them unchanged — confirm whether padding was intended.
    transforms.RandomCrop(size=32),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
trainset_x = CIFAR10(split='l_train', transform=transform_train)