Hello, I would like the MNIST dataset (torchvision package) to return an extra item, e.g. the mean pixel value of the image, along with the image and the target. I have created a subclass that overrides the `__getitem__` method and returns the extra item:
```python
import torchvision.datasets as dataset
from torch.utils.data import DataLoader
from PIL import Image, ImageStat


class MNIST1(dataset.MNIST):
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict: {"image": image, "target": target, "mean_pixel": mean_pixel},
            where target is the index of the target class.
        """
        img, target = super(MNIST1, self).__getitem__(index)

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        # the extra item to be returned
        mean_pixel = ImageStat.Stat(img).mean

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        sample = {"image": img, "target": target, "mean_pixel": mean_pixel}
        return sample
```
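If the mean of the *already-transformed* tensor is an acceptable extra item, one alternative avoids overriding MNIST internals entirely: wrap any dataset that returns `(image_tensor, target)`. This is a minimal sketch under that assumption; `WithMeanPixel` is a hypothetical name, and `TensorDataset` stands in for MNIST only so the example runs without a download.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class WithMeanPixel(Dataset):
    """Hypothetical wrapper: returns the wrapped (image, target) pair as a dict,
    plus the mean of the image tensor as returned by the wrapped dataset."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        img, target = self.base[index]
        # img is already a tensor here (e.g. after ToTensor/Normalize), so the
        # mean is taken over transformed values, not raw 0-255 pixels
        return {"image": img, "target": target, "mean_pixel": img.mean().item()}


# Stand-in dataset: two constant 1x28x28 "images" with labels 3 and 7
images = torch.stack([torch.full((1, 28, 28), 0.25), torch.full((1, 28, 28), 0.75)])
targets = torch.tensor([3, 7])
wrapped = WithMeanPixel(TensorDataset(images, targets))

sample = wrapped[1]
print(sample["target"].item(), sample["mean_pixel"])  # 7 0.75
```

With real data this would be something like `WithMeanPixel(torchvision.datasets.MNIST(root, train=True, transform=..., download=True))`. Note that with `Normalize((0.5,), (1.0,))` in the transform, this mean equals the raw mean divided by 255, minus 0.5, so it is on a different scale from what `ImageStat` reports.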
The data loading part:

```python
from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = MNIST1(root=root, train=True, transform=trans, download=True)
test_set = MNIST1(root=root, train=False, transform=trans, download=True)

batch_size = 100
train_loader = DataLoader(dataset=train_set, batch_size=batch_size,
                          shuffle=True, pin_memory=True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size,
                         shuffle=True, pin_memory=True)
```
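Once `__getitem__` works, it helps to know what a batch will look like. With the default `collate_fn`, `DataLoader` batches dict samples key by key. The sketch below uses hand-made samples (no download needed) to show this, including why a one-element list such as `ImageStat.Stat(img).mean` comes back as a list holding one tensor rather than a single tensor.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two hand-made samples shaped like what MNIST1.__getitem__ is meant to return
samples = [
    {"image": torch.zeros(1, 28, 28), "target": 3, "mean_pixel": 12.5},
    {"image": torch.ones(1, 28, 28), "target": 7, "mean_pixel": 40.0},
]

batch = default_collate(samples)
print(batch["image"].shape)       # torch.Size([2, 1, 28, 28])
print(batch["target"])            # tensor([3, 7])
print(batch["mean_pixel"].dtype)  # torch.float64 (Python floats become float64)

# If mean_pixel is a one-element list (as ImageStat.Stat(img).mean returns for
# an "L" image), it is collated as a list containing one tensor instead
listy = default_collate([{"mean_pixel": [12.5]}, {"mean_pixel": [40.0]}])
print(type(listy["mean_pixel"]), len(listy["mean_pixel"]))  # <class 'list'> 1
```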
Am I doing this correctly? I get an error when I try to iterate over the DataLoader:

```
ValueError: Too many dimensions: 3 > 2.
```

It comes from the `Image.fromarray(img.numpy(), mode='L')` line. Please suggest a solution.