Extra dimension in data loader?

I am using the following code to load the MNIST dataset:

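import torch
from torchvision import datasets, transforms

batch_size = 64
# ToTensor() turns each PIL image into a float tensor scaled to [0, 1]
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=batch_size)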

If I try to load one batch:

for data, target in train_loader:
    print(data.shape)
    break

this prints:

torch.Size([64, 1, 28, 28])

What I don’t understand is why the size is 64 x 1 x 28 x 28 instead of just 64 x 28 x 28. What is this extra dimension of length 1?

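dim1 is the channel dimension. MNIST images are grayscale, so each one has a single channel. transforms.ToTensor() returns every image as a tensor of shape [channels, height, width], here [1, 28, 28], and the DataLoader stacks 64 of them into [64, 1, 28, 28].
nn.Conv2d and the other 2D layers expect input of shape [batch_size, channels, height, width], so this layout can be fed to them directly.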
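A minimal sketch of what this means in practice. It uses a random tensor in place of a real MNIST batch (so nothing is downloaded), and the conv layer's out_channels=8 and kernel_size=3 are arbitrary illustrative choices, not anything MNIST requires. It shows that the size-1 channel dim is what nn.Conv2d consumes, and how to drop or restore it with squeeze/unsqueeze if you really want 64 x 28 x 28:

```python
import torch
import torch.nn as nn

# Stand-in for one MNIST batch: random values with the same shape the
# DataLoader produced (batch_size=64, 1 channel, 28 x 28 pixels).
data = torch.rand(64, 1, 28, 28)

# in_channels=1 matches the single grayscale channel (dim1).
# out_channels=8 and kernel_size=3 are arbitrary example values.
conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
out = conv(data)
print(out.shape)  # torch.Size([64, 8, 26, 26]): 28 - 3 + 1 = 26, no padding

# If you need [batch_size, height, width], remove the size-1 channel dim.
flat = data.squeeze(1)
print(flat.shape)  # torch.Size([64, 28, 28])

# unsqueeze(1) puts the channel dim back before passing to a 2D layer again.
restored = flat.unsqueeze(1)
print(restored.shape)  # torch.Size([64, 1, 28, 28])
```

squeeze(1) only removes dim1 because it has size 1, so the pixel values are untouched; the data is the same, only the shape bookkeeping changes.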
