Hi,
I have seen this and this post, based on them I wrote the following code:
def init_cifar_dataloader(root, batchSize):
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform_train = transforms.Compose([
MyPatch(32,32,3),
transforms.ToTensor(),
normalize
])
transform_test = transforms.Compose([
MyPatch(32,32,3),
transforms.ToTensor(),
normalize
])
dataset = dset.ImageFolder(root="test/0",transform=transform_train)
print('dataset',len(dataset))
trainset, testset = torch.utils.data.random_split(dataset, [8, 8])
print(type(trainset))
train_loader = DataLoader(trainset,
batch_size=batchSize, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(testset,
batch_size=batchSize * 8, shuffle=False, num_workers=4, pin_memory=True)
print(f'val set: {len(test_loader.dataset)}')
dataiter = iter(train_loader)
images, labels = dataiter.next()
# print('trainloader',train_loader)
# patch = patches(images, 32, 32, 3)
# print(images.shape, patch.shape)
# imshow(make_grid(patch[0:16,:,:]))
return train_loader, test_loader
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
def patches(img, kernel, stride, channels):
print(type(img))
img = transforms.ToTensor()(img).unsqueeze(0)
patches = img.unfold(2,kernel, stride).unfold(3,kernel, stride)
patches = patches.contiguous().view(-1, channels, kernel, kernel)
patches = transforms.ToPILImage()(patches.squeeze_(0))
return patches
class MyPatch(object):
def __init__(self, kernel, stride, channels):
self.kernel = kernel
self.stride = stride
self.channels = channels
def __call__(self,img):
return patches(img, self.kernel, self.stride, self.channels)
However I get TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>'
Full error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-190-d3a2cf03c1d9> in <module>
83
84
---> 85 init_cifar_dataloader('test',8)
<ipython-input-190-d3a2cf03c1d9> in init_cifar_dataloader(root, batchSize)
48
49 dataiter = iter(train_loader)
---> 50 images, labels = dataiter.next()
51
52 print('trainloader',train_loader)
/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
854 else:
855 del self._task_info[idx]
--> 856 return self._process_data(data)
857
858 def _try_put_index(self):
/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
879 self._try_put_index()
880 if isinstance(data, ExceptionWrapper):
--> 881 data.reraise()
882 return data
883
/apps/python3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
393 # (https://bugs.python.org/issue2651), so we work around it.
394 msg = KeyErrorMessage(msg)
--> 395 raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataset.py", line 257, in __getitem__
return self.dataset[self.indices[idx]]
File "/apps/python3/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 137, in __getitem__
sample = self.transform(sample)
File "/apps/python3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
img = t(img)
File "<ipython-input-190-d3a2cf03c1d9>", line 82, in __call__
return patches(img, self.kernel, self.stride, self.channels)
File "<ipython-input-190-d3a2cf03c1d9>", line 71, in patches
patches = transforms.ToPILImage()(patches.squeeze_(0))
File "/apps/python3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 127, in __call__
return F.to_pil_image(pic, self.mode)
File "/apps/python3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 104, in to_pil_image
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
Any help is welcome!