class MyDataset(torch.utils.data.Dataset):
def __init__(self,dataset = None, transform= None):
self.MNIST = dataset
#print(self.MNIST)
self.transform = transform
def __getitem__(self, index):
data, target = self.MNIST[index]
#print(img.shape)
if self.transform is not None:
#print(data.size())
#tr_data = TF.to_pil_image(data)
#print(tr_data.size())
#tr_data = TF.hflip(tr_data) #### 10 degree rotations
#print(tr_data.size())
#tr_data = TF.to_tensor(tr_data)
tr_data = self.transform(data)
return data, tr_data, target, index
def __len__(self):
return len(self.MNIST)
train_dataset = datasets.MNIST(root=’./data’, train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]))
trin_transform=transforms.Compose([transforms.RandomHorizontalFlip(p=1),
transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])
trainset = MyDataset(dataset=train_dataset, transform = trin_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=24, shuffle=False, num_workers=2)
The error is coming like “TypeError: img should be PIL Image. Got <class ‘torch.Tensor’>”