Creating a custom dataset class to do data augmentation on numpy arrays - getting an error

Hi, I am a beginner in PyTorch. I am trying to define a custom dataset class that converts my data from numpy arrays to PIL images so that I can apply augmentations. However, I am getting an error.
My data is:
X: numpy array of shape (383, 1, 1000, 100) (383 single-channel 1000x100 images)
y: numpy array with the labels.
I am using the following code:

`batchsize = 2

transform = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.Grayscale(num_output_channels=3),
     transforms.RandomRotation(degrees = 45),
     transforms.RandomHorizontalFlip,
     transforms.RandomVerticalFlip,
     transforms.ToTensor()
])

class MyDataset(Dataset):
    """Map-style dataset over in-memory numpy arrays with PIL-based augmentation.

    Args:
        data: numpy array of samples. Each sample must have a leading
            singleton channel axis, e.g. shape (N, 1, H, W); it is squeezed
            away before building the PIL image.
        target: numpy array of integer labels, one per sample.
        transform: callable applied to the PIL image built from each sample.
            It must accept a ``PIL.Image`` (so it must not begin with
            ``ToPILImage``). Defaults to the module-level ``transform``
            pipeline. If None, the un-augmented float tensor of shape
            (1, H, W) is returned instead.
    """

    def __init__(self, data, target, transform=transform):
        self.data = torch.from_numpy(data).float()
        self.target = torch.from_numpy(target).long()
        self.transform = transform

    def __len__(self):
        return len(self.target)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for ``index``."""
        x = self.data[index]
        y = self.target[index]

        if self.transform is None:
            # Nothing to augment: return the tensor directly so that the
            # default collate function can still batch it (a PIL image could
            # not be collated).
            return x, y

        # np.squeeze(axis=0) drops the channel axis and raises ValueError if
        # that axis is not size 1. Image.fromarray needs a 2-D array for a
        # single-channel image; a float32 array yields a mode 'F' image.
        img = Image.fromarray(np.squeeze(x.numpy(), axis=0))
        # NOTE(review): torchvision's Grayscale converts through 8-bit mode
        # 'L', so fractional values or values outside 0-255 are likely
        # clipped or quantized. Confirm the range of `data` and rescale
        # (e.g. to uint8 0-255) first if needed.
        return self.transform(img), y
    


# Wrap the arrays X (images) and y (labels), both defined outside this
# snippet, using MyDataset's default transform.
dataset = MyDataset(X, y)
loader = DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=0,
    pin_memory=True,  
    # NOTE(review): _init_fn is not defined in this snippet (presumably in an
    # earlier cell; confirm). With num_workers=0 batches are loaded in the
    # main process, so this hook is likely never invoked.
    worker_init_fn=_init_fn
)

# Smoke test: iterate the loader once and print each batch's tensor shapes.
for batch_idx, (data, target) in enumerate(loader):
    print('Batch idx {}, data shape {}, target shape {}'.format(
        batch_idx, data.shape, target.shape))`

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-152-425dcf15a5af> in <module>
     61 )
     62 
---> 63 for batch_idx, (data, target) in enumerate(loader):
     64     print('Batch idx {}, data shape {}, target shape {}'.format(
     65         batch_idx, data.shape, target.shape))

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-152-425dcf15a5af> in __getitem__(self, index)
     27         b = np.squeeze(a, axis=0)
     28         c = Image.fromarray(b)
---> 29         d = self.transform(c)
     30 
     31         return d, y

/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, pic)
    134 
    135         """
--> 136         return F.to_pil_image(pic, self.mode)
    137 
    138     def __repr__(self):

/opt/conda/lib/python3.7/site-packages/torchvision/transforms/functional.py in to_pil_image(pic, mode)
    118     """
    119     if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
--> 120         raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
    121 
    122     elif isinstance(pic, torch.Tensor):

TypeError: pic should be Tensor or ndarray. Got <class 'PIL.Image.Image'>.

If I instead change my transform to:

# Same pipeline without ToPILImage(). The flip transforms must be instantiated
# with (). Passing the bare classes makes Compose call
# RandomVerticalFlip(<previous output>), which hands a transform object to
# ToTensor. That is the "pic should be PIL Image or ndarray" TypeError.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.RandomRotation(degrees=45),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ]
)

I get this error:

TypeError                                 Traceback (most recent call last)
<ipython-input-12-cee2846210e9> in <module>
     62 )
     63 
---> 64 for batch_idx, (data, target) in enumerate(loader):
     65     print('Batch idx {}, data shape {}, target shape {}'.format(
     66         batch_idx, data.shape, target.shape))

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-12-cee2846210e9> in __getitem__(self, index)
     28         b = np.squeeze(a, axis=0)
     29         c = Image.fromarray(b)
---> 30         d = self.transform(c)
     31 
     32         return d, y

/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, pic)
     99             Tensor: Converted image.
    100         """
--> 101         return F.to_tensor(pic)
    102 
    103     def __repr__(self):

/opt/conda/lib/python3.7/site-packages/torchvision/transforms/functional.py in to_tensor(pic)
     53     """
     54     if not(_is_pil_image(pic) or _is_numpy(pic)):
---> 55         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
     56 
     57     if _is_numpy(pic) and not _is_numpy_image(pic):

TypeError: pic should be PIL Image or ndarray. Got <class 'torchvision.transforms.transforms.RandomVerticalFlip'>

How can I fix my code?
Thanks

For the first error, it looks like your data is already a PIL image (you create it with `Image.fromarray` in `__getitem__`), so `ToPILImage` fails.

For the second error, you forgot the parentheses. Instead of `transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip`, try using `transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()`.

1 Like