How to extract patches from an image batch using transforms.Compose

Hi,

I have seen this and this post; based on them, I wrote the following code:

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import make_grid

def init_cifar_dataloader(root, batchSize):
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform_train = transforms.Compose([
        MyPatch(32, 32, 3),
        transforms.ToTensor(),
        normalize
    ])
    transform_test = transforms.Compose([
        MyPatch(32, 32, 3),
        transforms.ToTensor(),
        normalize
    ])
    
    dataset = dset.ImageFolder(root="test/0", transform=transform_train)
    print('dataset', len(dataset))

    trainset, testset = torch.utils.data.random_split(dataset, [8, 8])
    print(type(trainset))

    train_loader = DataLoader(trainset,
                              batch_size=batchSize, shuffle=True, num_workers=4, pin_memory=True)
    
    test_loader = DataLoader(testset,
                             batch_size=batchSize * 8, shuffle=False, num_workers=4, pin_memory=True)
    print(f'val set: {len(test_loader.dataset)}')

    dataiter = iter(train_loader)
    images, labels = dataiter.next()
    
#     print('trainloader',train_loader)

    
#     patch = patches(images, 32, 32, 3)
#     print(images.shape, patch.shape)

#     imshow(make_grid(patch[0:16,:,:]))
    return train_loader, test_loader

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

def patches(img, kernel, stride, channels):
    print(type(img))
    img = transforms.ToTensor()(img).unsqueeze(0)

    patches = img.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.contiguous().view(-1, channels, kernel, kernel)
    patches = transforms.ToPILImage()(patches.squeeze_(0))
    return patches


class MyPatch(object):
    def __init__(self, kernel, stride, channels):
        self.kernel = kernel
        self.stride = stride
        self.channels = channels

    def __call__(self, img):
        return patches(img, self.kernel, self.stride, self.channels)


However I get ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

Full error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-190-d3a2cf03c1d9> in <module>
     83 
     84 
---> 85 init_cifar_dataloader('test',8)

<ipython-input-190-d3a2cf03c1d9> in init_cifar_dataloader(root, batchSize)
     48 
     49     dataiter = iter(train_loader)
---> 50     images, labels = dataiter.next()
     51 
     52     print('trainloader',train_loader)

/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    854             else:
    855                 del self._task_info[idx]
--> 856                 return self._process_data(data)
    857 
    858     def _try_put_index(self):

/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
    879         self._try_put_index()
    880         if isinstance(data, ExceptionWrapper):
--> 881             data.reraise()
    882         return data
    883 

/apps/python3/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    393             # (https://bugs.python.org/issue2651), so we work around it.
    394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/apps/python3/lib/python3.7/site-packages/torch/utils/data/dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/apps/python3/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 137, in __getitem__
    sample = self.transform(sample)
  File "/apps/python3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "<ipython-input-190-d3a2cf03c1d9>", line 82, in __call__
    return  patches(img, self.kernel, self.stride, self.channels)
  File "<ipython-input-190-d3a2cf03c1d9>", line 71, in patches
    patches = transforms.ToPILImage()(patches.squeeze_(0))
  File "/apps/python3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 127, in __call__
    return F.to_pil_image(pic, self.mode)
  File "/apps/python3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 104, in to_pil_image
    raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

Any help is welcome!

Hi,
I think this is the line (quoted from your traceback):

patches = transforms.ToPILImage()(patches.squeeze_(0))

The reason is that if you have an image of size [3, 100, 100] and extract 4 patches from it, then patches has shape [4, 3, h, w], i.e. a 4D tensor. This cannot be converted to a single image by ToPILImage, which only accepts 2D or 3D inputs.
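
You can see the shape problem directly. A minimal sketch, assuming a 3 x 100 x 100 input and 50 x 50 non-overlapping patches (the variable names are just for illustration):

import torch
import torchvision.transforms as transforms

img = torch.rand(3, 100, 100)                    # C x H x W image tensor in [0, 1]
x = img.unsqueeze(0)                             # 1 x C x H x W
p = x.unfold(2, 50, 50).unfold(3, 50, 50)        # 1 x C x 2 x 2 x 50 x 50
p = p.permute(0, 2, 3, 1, 4, 5).contiguous()     # 1 x 2 x 2 x C x 50 x 50
p = p.view(-1, 3, 50, 50)                        # 4 x 3 x 50 x 50: four patches
print(p.shape)                                   # torch.Size([4, 3, 50, 50])

transforms.ToPILImage()(p[0])                    # fine: a single 3D patch
# transforms.ToPILImage()(p)                     # raises the ValueError above: 4 dimensions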

I suggest that you write your own custom transformation that accepts tensors, instead of converting to a tensor, extracting the patches, and then converting back to an image. After all, you convert the result back to a tensor anyway in order to apply normalize, so why not extract the patches directly from the normalized tensor?
Also, in terms of performance this approach will be faster, as you no longer have to deal with the conversion overhead.
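
Something along these lines could work. This is a minimal sketch (the class name ExtractPatches and its reshaping are my own illustration, not taken from your code); the key point is that it runs after ToTensor and Normalize, so it only ever sees tensors and never touches ToPILImage:

import torchvision.transforms as transforms

class ExtractPatches(object):
    """Extract patches from a C x H x W tensor, returning N x C x k x k."""
    def __init__(self, kernel, stride):
        self.kernel = kernel
        self.stride = stride

    def __call__(self, img):                               # img: C x H x W tensor
        c = img.size(0)
        p = img.unfold(1, self.kernel, self.stride)        # C x nH x W x k
        p = p.unfold(2, self.kernel, self.stride)          # C x nH x nW x k x k
        p = p.permute(1, 2, 0, 3, 4).contiguous()          # nH x nW x C x k x k
        return p.view(-1, c, self.kernel, self.kernel)     # N x C x k x k

transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ExtractPatches(32, 32),                                # runs last, on the normalized tensor
])

Note that each sample then has shape [N, C, 32, 32], so the DataLoader yields batches of shape [batch_size, N, C, 32, 32]; you can merge the first two dimensions with view(-1, C, 32, 32) if you want one flat batch of patches.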

Bests