Issue while using Albumentations for image augmentation



TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/content/dataset.py", line 27, in __getitem__
    image = plt.imread(self.image_list[i])
  File "/usr/local/lib/python3.6/dist-packages/matplotlib/pyplot.py", line 2135, in imread
    return matplotlib.image.imread(fname, format)
  File "/usr/local/lib/python3.6/dist-packages/matplotlib/image.py", line 1436, in imread
    return handler(fname)
  File "/usr/local/lib/python3.6/dist-packages/matplotlib/image.py", line 1390, in read_png
    return _png.read_png(*args, **kwargs)
TypeError: Object does not appear to be a 8-bit string path or a Python file-like object




=================================================================================
import torch
import torchvision
import torchvision.transforms as transforms
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class AlbumentationImageDataset(Dataset):
  def __init__(self, image_list):
      self.image_list = image_list

      self.aug = A.Compose([
          A.Resize(200, 300),
          A.CenterCrop(100, 100),
          A.RandomCrop(80, 80),
          A.HorizontalFlip(p=0.5),
          A.Rotate(limit=(-90, 90)),
          A.VerticalFlip(p=0.5),
          A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ])
       
  def __len__(self):
      return (len(self.image_list))

  def __getitem__(self, i):
      image = plt.imread(self.image_list[i])
      image = Image.fromarray(image).convert('RGB')
      image = self.aug(image=np.array(image))['image']
      image = np.transpose(image, (2, 0, 1)).astype(np.float32)
          
      return torch.tensor(image, dtype=torch.float)


def getData():
  transform_test = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  ])
  transform_train = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
      transforms.RandomRotation((-10.0, 10.0)),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
  ])
  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True)
  print(trainset)
  alb_dataset = AlbumentationImageDataset(image_list=trainset)

  trainloader = DataLoader(alb_dataset, batch_size=128,
                           shuffle=True, num_workers=2)

  testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform_test)
  testloader = DataLoader(testset, batch_size=100,
                          shuffle=False, num_workers=2)

  classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  return trainloader, testloader, classes

The error seems to be raised by matplotlib while trying to read the image.
What kind of image format are you using?
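For reference: image_list here is the CIFAR10 dataset object itself, so self.image_list[i] returns a (PIL image, label) tuple rather than a file path or file-like object, which is what plt.imread is complaining about. Below is a minimal sketch of how a wrapper dataset could hand the PIL image straight to Albumentations as a NumPy array; it is only an illustration, not the original code (the class name AlbumentationsWrapper, the argument base_dataset, and the shortened transform list are made up for the example):

import albumentations as A
import numpy as np
import torch
from torch.utils.data import Dataset

class AlbumentationsWrapper(Dataset):
    """Wraps a dataset that yields (PIL image, label) pairs, e.g. torchvision's CIFAR10."""

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset
        # A list (not a set) so the transforms run in a fixed order.
        self.aug = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, i):
        image, label = self.base_dataset[i]      # PIL image, int label
        image = np.array(image.convert('RGB'))   # HWC uint8 array, as Albumentations expects
        image = self.aug(image=image)['image']
        image = np.transpose(image, (2, 0, 1))   # CHW layout for PyTorch
        return torch.tensor(image, dtype=torch.float), label

With something like this, the wrapper would be built as AlbumentationsWrapper(torchvision.datasets.CIFAR10(root='./data', train=True, download=True)), and plt.imread would not be needed at all.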