Batch of images, change dimension


I am trying to preprocess RGB images to create batches of the correct shape (as required by my network).
Example: CIFAR10

import torch
import torchvision
from torchvision import datasets, transforms

bs = 64  # batch size (example value)
transfos = transforms.Compose([torchvision.transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize((0.4749,), (0.2382,))])
train = datasets.CIFAR10(".", train=True, download=True, transform=transfos)
test = datasets.CIFAR10(".", train=False, download=True, transform=transfos)

batches_train = torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True)
batches_test = torch.utils.data.DataLoader(test, batch_size=bs, shuffle=True)

The problem now is that each batch has shape (batch_size, 1, H, W).
I would like every batch to have shape (batch_size, H*W) instead. Is there a simple way to do this with transforms?
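For context, here is a small sketch of where the extra dimension comes from. It assumes PyTorch's default collate behaviour and uses random tensors as stand-ins for the transformed CIFAR10 images: after Grayscale + ToTensor each sample is a (1, 32, 32) tensor, and the DataLoader's default collate function stacks the samples along a new leading batch dimension.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Stand-ins for transformed CIFAR10 samples: (image of shape (1, 32, 32), int label).
samples = [(torch.randn(1, 32, 32), i % 10) for i in range(8)]

# The default collate function stacks the per-sample tensors along a new
# leading dimension, so the channel dimension of size 1 is kept.
images, labels = default_collate(samples)
print(images.shape)  # torch.Size([8, 1, 32, 32])
print(labels.shape)  # torch.Size([8])
```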

Thank you!


Does something like this work?

>>> import torch
>>> data = torch.randn(32, 1, 32, 32)
>>> data.shape
torch.Size([32, 1, 32, 32])
>>> data = data.reshape(32, -1)
>>> data.shape
torch.Size([32, 1024])
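One caveat with hard-coding 32: unless the DataLoader uses drop_last=True, the last batch is usually smaller. A sketch of the per-batch reshape in a training loop that takes the size from the batch itself (the random TensorDataset is a hypothetical stand-in for CIFAR10):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for grayscale CIFAR10: 100 random (1, 32, 32) images with labels.
dataset = TensorDataset(torch.randn(100, 1, 32, 32), torch.randint(0, 10, (100,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_shapes = []
for images, labels in loader:
    # Use the actual batch size: with 100 samples and batch_size=32,
    # the last batch only holds 4 images.
    images = images.reshape(images.size(0), -1)  # same as torch.flatten(images, start_dim=1)
    batch_shapes.append(tuple(images.shape))

print(batch_shapes)  # [(32, 1024), (32, 1024), (32, 1024), (4, 1024)]
```

With a hard-coded reshape(32, -1), that final 4-image batch (4096 values) would not raise an error but would silently become a (32, 128) tensor.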


That’s how I do it right now, but it strikes me as very inefficient as I have to do this each time my Dataloader hands me a new batch. I was hoping there is a preprocessing solution so that my Dataloader hands me the batches in the correct shape (no more reshaping during training).

You can also do the reshape in your dataloader if you think that is a significant source of inefficiency. However, this is unlikely: for a contiguous tensor (which a freshly collated batch is), reshape returns a view, so it doesn’t change the data layout or copy any data. It just changes the shape and strides used for indexing.
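Doing the reshape in the DataLoader could look like the following sketch. It passes a hypothetical `flatten_collate` function as the DataLoader's `collate_fn`: the function first builds the usual batch with the default collate function and then flattens everything except the batch dimension. The random TensorDataset again stands in for CIFAR10.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate

def flatten_collate(batch):
    # Hypothetical helper: build the usual (batch_size, 1, H, W) batch first ...
    images, labels = default_collate(batch)
    # ... then flatten all but the batch dimension. The stacked batch is
    # contiguous, so this reshape returns a view without copying image data.
    return images.reshape(images.size(0), -1), labels

# Stand-in data: 10 grayscale 32x32 "images" with labels.
dataset = TensorDataset(torch.randn(10, 1, 32, 32), torch.randint(0, 10, (10,)))
loader = DataLoader(dataset, batch_size=4, collate_fn=flatten_collate)

shapes = [tuple(images.shape) for images, _ in loader]
print(shapes)  # [(4, 1024), (4, 1024), (2, 1024)]
```

This keeps the dataset and its transforms unchanged and only alters how samples are combined into batches.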

Alright, then it might not be computationally inefficient, but it would help readability if I could tell the dataloader from the get-go what to do. How do I do it in the Dataloader? The Dataloader is not a simple tensor.

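You can add another transform function to your list of transforms to do the reshape. Keep in mind that the transforms in Compose are applied to each image individually, before the DataLoader batches them, so the function receives a single (1, H, W) tensor rather than a (batch_size, 1, H, W) batch. It should therefore flatten the whole sample to a vector of length H*W; the DataLoader then stacks those vectors into (batch_size, H*W). E.g.,

def reshape_transform(data):
    # `data` is one image of shape (1, H, W); flatten it to shape (H*W,).
    # (Using data.size(0) as the first dimension here would give (1, H*W)
    # per image, i.e. batches of shape (batch_size, 1, H*W).)
    # A module-level function, unlike a nested one, can be pickled, which
    # DataLoader workers need when num_workers > 0 uses the spawn start method.
    return data.reshape(-1)

transfos = transforms.Compose([torchvision.transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize((0.4749,), (0.2382,)), reshape_transform])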

This should do it!
I will try it asap. Thank you!