Error while loading a dataset from torchvision.datasets

I am trying to load MNIST from torchvision.datasets, as shown on the PyTorch website.

dataset = torchvision.datasets.MNIST('data_nn/')
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

But when I iterate over the data_loader, that is

for image,label in data_loader():
    print(image)

I get the following error:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

So it is clear to me what is happening but I don’t really know how to fix it.

Something I could do is take dataset.data and pass that to the DataLoader, but that only gives me the images without the labels.

Thanks!

You need to convert the PIL images to tensors. torchvision has a built-in transform for this, transforms.ToTensor(), which you can pass to the dataset through the transform argument, as shown below.

dataset = torchvision.datasets.MNIST("data_nn/",transform=torchvision.transforms.ToTensor()) 

Also iterate over data_loader directly instead of calling it (a DataLoader is not callable, so data_loader() raises a TypeError of its own):

for step,(img,label) in enumerate(data_loader):
    print(img.shape)