TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/content/dataset.py", line 27, in __getitem__
image = plt.imread(self.image_list[i])
File "/usr/local/lib/python3.6/dist-packages/matplotlib/pyplot.py", line 2135, in imread
return matplotlib.image.imread(fname, format)
File "/usr/local/lib/python3.6/dist-packages/matplotlib/image.py", line 1436, in imread
return handler(fname)
File "/usr/local/lib/python3.6/dist-packages/matplotlib/image.py", line 1390, in read_png
return _png.read_png(*args, **kwargs)
TypeError: Object does not appear to be a 8-bit string path or a Python file-like object
=================================================================================
import torch
import torchvision
import torchvision.transforms as transforms
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset
class AlbumentationImageDataset(Dataset):
def __init__(self, image_list):
self.image_list = image_list
self.aug = A.Compose({
A.Resize(200, 300),
A.CenterCrop(100, 100),
A.RandomCrop(80, 80),
A.HorizontalFlip(p=0.5),
A.Rotate(limit=(-90, 90)),
A.VerticalFlip(p=0.5),
A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
})
def __len__(self):
return (len(self.image_list))
def __getitem__(self, i):
image = plt.imread(self.image_list[i])
image = Image.fromarray(image).convert('RGB')
image = self.aug(image=np.array(image))['image']
image = np.transpose(image, (2, 0, 1)).astype(np.float32)
return torch.tensor(image, dtype=torch.float)
def getData():
transform_test = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
transforms.RandomRotation((-10.0, 10.0)), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True)
print(trainset)
alb_dataset = AlbumentationImageDataset(image_list=trainset)
trainloader = DataLoader(alb_dataset, batch_size= 128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
return trainloader, testloader, classes
1 Like
The error seems to be raised by matplotlib
while trying to read the image.
What kind of image format are you using?