Hi, I am a beginner in pytorch. I am trying to define a custom dataset class to transform my data from numpy arrays to PIL imaages to do augmentations. However I am getting an error.
My data is: X: numpy array of shape (383, 1, 1000, 100) (383 single channel images 1000x100)
y: numpy array with the labels.
I am using the following code:
`batchsize = 2
transform = transforms.Compose(
[transforms.ToPILImage(),
transforms.Grayscale(num_output_channels=3),
transforms.RandomRotation(degrees = 45),
transforms.RandomHorizontalFlip,
transforms.RandomVerticalFlip,
transforms.ToTensor()
])
class MyDataset(Dataset):
def __init__(self, data, target, transform=transform):
self.data = torch.from_numpy(data).float()
self.target = torch.from_numpy(target).long()
self.transform = transform
def __len__(self):
return len(self.target)
def __getitem__(self, index):
x = self.data[index]
y = self.target[index]
a = x.numpy()
b = np.squeeze(a, axis=0)
c = Image.fromarray(b)
d = self.transform(c)
return d, y
dataset = MyDataset(X, y)
loader = DataLoader(
dataset,
batch_size=batchsize,
shuffle=True,
num_workers=0,
pin_memory=True,
worker_init_fn=_init_fn
)
for batch_idx, (data, target) in enumerate(loader):
print('Batch idx {}, data shape {}, target shape {}'.format(
batch_idx, data.shape, target.shape))`
I get the following error:
TypeError Traceback (most recent call last)
<ipython-input-152-425dcf15a5af> in <module>
61 )
62
---> 63 for batch_idx, (data, target) in enumerate(loader):
64 print('Batch idx {}, data shape {}, target shape {}'.format(
65 batch_idx, data.shape, target.shape))
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-152-425dcf15a5af> in __getitem__(self, index)
27 b = np.squeeze(a, axis=0)
28 c = Image.fromarray(b)
---> 29 d = self.transform(c)
30
31 return d, y
/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
68 def __call__(self, img):
69 for t in self.transforms:
---> 70 img = t(img)
71 return img
72
/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, pic)
134
135 """
--> 136 return F.to_pil_image(pic, self.mode)
137
138 def __repr__(self):
/opt/conda/lib/python3.7/site-packages/torchvision/transforms/functional.py in to_pil_image(pic, mode)
118 """
119 if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
--> 120 raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
121
122 elif isinstance(pic, torch.Tensor):
TypeError: pic should be Tensor or ndarray. Got <class 'PIL.Image.Image'>.
While if I change my transform to:
transform = transforms.Compose(
[
transforms.Grayscale(num_output_channels=3),
transforms.RandomRotation(degrees = 45),
transforms.RandomHorizontalFlip,
transforms.RandomVerticalFlip,
transforms.ToTensor()
])
I get this error:
TypeError Traceback (most recent call last)
<ipython-input-12-cee2846210e9> in <module>
62 )
63
---> 64 for batch_idx, (data, target) in enumerate(loader):
65 print('Batch idx {}, data shape {}, target shape {}'.format(
66 batch_idx, data.shape, target.shape))
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-12-cee2846210e9> in __getitem__(self, index)
28 b = np.squeeze(a, axis=0)
29 c = Image.fromarray(b)
---> 30 d = self.transform(c)
31
32 return d, y
/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, img)
68 def __call__(self, img):
69 for t in self.transforms:
---> 70 img = t(img)
71 return img
72
/opt/conda/lib/python3.7/site-packages/torchvision/transforms/transforms.py in __call__(self, pic)
99 Tensor: Converted image.
100 """
--> 101 return F.to_tensor(pic)
102
103 def __repr__(self):
/opt/conda/lib/python3.7/site-packages/torchvision/transforms/functional.py in to_tensor(pic)
53 """
54 if not(_is_pil_image(pic) or _is_numpy(pic)):
---> 55 raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
56
57 if _is_numpy(pic) and not _is_numpy_image(pic):
TypeError: pic should be PIL Image or ndarray. Got <class 'torchvision.transforms.transforms.RandomVerticalFlip'>
How can I fix my code?
Thanks