Dataloader splitting

How do I split images into patches in a DataLoader with PyTorch?

unfold works on the specified dimension, so you would most likely have to increase the specified dimension indices by one, assuming the batch dimension is dim0 and you don’t want to unfold it.

In the code that I found on this forum, it seems that I am setting x to be only one image (28 by 28).

x = mnist_train[0][0][-1, :, :]

But I want to make patches for each of the images, not just one (whereas ‘x’ is currently only one 28x28 image). I then want to flatten the patches and put them through a linear layer to get a matrix.
Someone suggested einops, but I’m not sure how to code this either.

Sorry, I’m new to deep learning.

Any help is appreciated.

What @ptrblck means is that you can use the same code for all of the images at once; you just have to take the right dimensions into account.

So if you only have one image (like x in your example), you only have HxW = [28x28].

If you had 1 image with a channel dimension, then you have to include the Channel → CxHxW = [1x28x28] for a grayscale image, or [3x28x28] for an RGB image.

Now if you have a batch of RGB images, this becomes → [BxCxHxW].

For the Dataset you are using, the underlying data tensor (mnist_train.data) has shape BxHxW = [60000x28x28], meaning you have 60000 images (no channel, because they are grayscale) of 28x28 pixels.
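A minimal sketch of these shape conventions, using a random uint8 tensor as a stand-in for the real dataset tensor (the stand-in and its smaller batch of 10 are assumptions for illustration). Adding a channel dimension with unsqueeze turns a grayscale [B, H, W] batch into [B, C, H, W] with C = 1:

```python
import torch

# Stand-in for the grayscale dataset tensor: B x H x W (B reduced to 10 here)
batch = torch.randint(0, 256, (10, 28, 28), dtype=torch.uint8)

single = batch[0]                   # one image: H x W
with_channel = single.unsqueeze(0)  # one image with a channel: C x H x W
batched = batch.unsqueeze(1)        # a batch with a channel: B x C x H x W

print(single.shape)        # torch.Size([28, 28])
print(with_channel.shape)  # torch.Size([1, 28, 28])
print(batched.shape)       # torch.Size([10, 1, 28, 28])
```

With the extra channel dimension in place, the RGB-style unfolding (starting at dim 2) also works for grayscale data.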

The first argument you pass to unfold is the dimension along which the unfolding happens (the second is the patch size and the third is the step). So if you want to leave the Batch alone and perform the same operation, you just have to shift the dimension indices one to the right.

# Dim 0 is the Batch
x = mnist_train.data.unfold(1, 7, 7).unfold(2, 7, 7)
# The new shape is
# torch.Size([60000, 4, 4, 7, 7])
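To see what the three arguments do, here is a small sketch on a toy batch of two 4x4 "images" (the toy sizes are assumptions chosen so the result is easy to inspect). unfold(dimension, size, step) slides a window of length size along dimension with stride step and appends the window as a new last dimension; with size equal to step, the patches do not overlap:

```python
import torch

# Toy batch: B=2 images of 4x4, values 0..31 so each patch is recognizable
x = torch.arange(2 * 4 * 4).reshape(2, 4, 4)

# Leave dim 0 (batch) alone; cut rows (dim 1), then columns (dim 2), into 2x2 patches
patches = x.unfold(1, 2, 2).unfold(2, 2, 2)
print(patches.shape)  # torch.Size([2, 2, 2, 2, 2])

# patches[b, i, j] is the 2x2 block starting at row 2*i, column 2*j of image b
print(torch.equal(patches[1, 0, 1], x[1, 0:2, 2:4]))  # True
```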

If you now wanted to do this with a batch of RGB images, you also have to take the extra channel dimension into consideration:

# Dim 0 is the Batch
# Dim 1 is the Channel
# images: hypothetical batch tensor of shape [B, C, 28, 28]
x = images.unfold(2, 7, 7).unfold(3, 7, 7)
# The new shape is
# torch.Size([60000, CHANNEL, 4, 4, 7, 7])
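A quick check of the RGB case on a random stand-in batch (8 images with 3 channels; both numbers are arbitrary assumptions). Because batch and channel occupy dims 0 and 1, the unfolding starts at dim 2:

```python
import torch

# Stand-in batch of RGB images: B x C x H x W
images = torch.rand(8, 3, 28, 28)

# Dims 0 (batch) and 1 (channel) are untouched; H and W are cut into 7x7 patches
patches = images.unfold(2, 7, 7).unfold(3, 7, 7)
print(patches.shape)  # torch.Size([8, 3, 4, 4, 7, 7])

# Patch at grid position (row 2, col 3) of image 5, channel 1 matches plain slicing
print(torch.equal(patches[5, 1, 2, 3], images[5, 1, 14:21, 21:28]))  # True
```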

You can compare against the single-image version to see that the results are the same:

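import torchvision
import torchvision.transforms as transforms

# I tried a couple of images to see if I get the same results
img = 900

mnist_train = torchvision.datasets.FashionMNIST(root="…/data", train=True, transform=transforms.Compose([transforms.ToTensor()]), download=True)

# Single image: ToTensor gives [1, 28, 28] floats in [0, 1]; [-1] drops the channel
x = mnist_train[img][0][-1, :, :]
x = x.unfold(0, 7, 7).unfold(1, 7, 7)

# Whole dataset: uint8 tensor [60000, 28, 28], so the unfold dims are shifted by one
y = mnist_train.data.unfold(1, 7, 7).unfold(2, 7, 7)

# Divide by 255 to match the scaling ToTensor applies before comparing
print((y[img]/255 == x).all())
# Output:
# tensor(True)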
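If you want to run the same kind of comparison without downloading the dataset, here is a sketch with a random uint8 stand-in batch (an assumption replacing mnist_train.data). It checks that unfolding the whole batch at once gives exactly the same patches as unfolding every image separately:

```python
import torch

# Stand-in for mnist_train.data: B x H x W, uint8 like the real dataset tensor
data = torch.randint(0, 256, (100, 28, 28), dtype=torch.uint8)

# Batched: the unfold dimensions are shifted by one to skip the batch dimension
batched = data.unfold(1, 7, 7).unfold(2, 7, 7)

# Per image: the single-image code applied in a loop, then stacked back together
per_image = torch.stack([img.unfold(0, 7, 7).unfold(1, 7, 7) for img in data])

print(batched.shape)                    # torch.Size([100, 4, 4, 7, 7])
print(torch.equal(batched, per_image))  # True
```

The batched version avoids the Python loop entirely, since unfold only creates a view over the existing data.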

If you then want to flatten each 7x7 patch, you can do something like:

y = y.flatten(start_dim=3)
# The new shape is
# torch.Size([60000, 4, 4, 49])

If you also want to flatten the 4x4 grid of patches into a sequence of 16 patches, you can use this:

y = y.reshape(-1, 16, 49)
# The new shape is
# torch.Size([60000, 16, 49])
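The original question also asked about putting the flattened patches through a linear layer. A minimal sketch on a random stand-in batch, assuming an output size of 64 per patch (an arbitrary choice) and converting the uint8 data to float first, since nn.Linear expects floating-point input. It also checks the unfold result against a plain view/permute formulation, which is the same layout the einops pattern 'b (h p1) (w p2) -> b (h w) (p1 p2)' describes:

```python
import torch
import torch.nn as nn

# Stand-in batch (B x H x W); scale to [0, 1] floats the way ToTensor does
data = torch.randint(0, 256, (32, 28, 28), dtype=torch.uint8)
x = data.float() / 255

# Patches: [B, 4, 4, 7, 7] -> [B, 4, 4, 49] -> [B, 16, 49]
patches = x.unfold(1, 7, 7).unfold(2, 7, 7).flatten(start_dim=3).reshape(-1, 16, 49)

# Same layout via view/permute: split H and W into 4 blocks of 7, then regroup
alt = x.view(-1, 4, 7, 4, 7).permute(0, 1, 3, 2, 4).reshape(-1, 16, 49)
print(torch.equal(patches, alt))  # True

# Linear layer applied to every 49-value patch (output size 64 is an assumption)
linear = nn.Linear(49, 64)
out = linear(patches)
print(out.shape)  # torch.Size([32, 16, 64])
```

nn.Linear operates on the last dimension, so each of the 16 patches in every image is projected independently, giving one 16x64 matrix per image.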

Hope this helps! :)