# !pip install torchvision
import torch
import torchvision
from torchvision import datasets, transforms

mnist2 = torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                                    transform=transforms.Compose([
                                        transforms.Pad(18, fill=0, padding_mode='constant'),
                                        transforms.ToTensor()
                                    ]))
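As a rough illustration of what this pipeline does to a single digit, here is a NumPy sketch that mimics the two transforms. It is not torchvision's implementation. The 28x28 array is synthetic, not real MNIST data, and `np.pad` and the `[np.newaxis]` axis stand in for `Pad` and `ToTensor`. Pad(18) adds 18 pixels on every side, so each side grows from 28 to 28 + 2 * 18 = 64, and ToTensor turns an H x W grayscale image into a 1 x H x W float tensor scaled to [0, 1].

```python
import numpy as np

# Synthetic stand-in for one raw MNIST digit: a 28x28 uint8 array (not real data).
raw = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)

# Mimic transforms.Pad(18, fill=0, padding_mode='constant'):
# 18 zero-valued pixels are added on every side, so each side grows by 2 * 18.
padded = np.pad(raw, pad_width=18, mode="constant", constant_values=0)

# Mimic transforms.ToTensor() for a grayscale image:
# add a leading channel axis (C x H x W) and rescale 0..255 to 0.0..1.0.
tensor_like = padded[np.newaxis, :, :].astype(np.float32) / 255.0

print(raw.shape)          # (28, 28)
print(padded.shape)       # (64, 64)
print(tensor_like.shape)  # (1, 64, 64)
```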
The following code block prints torch.Size([28, 28]) for img2.shape:
import matplotlib.pyplot as plt
img2 = mnist2.data[0]
print(img2.shape)
plt.imshow(img2)
plt.show()
The following block, on the other hand, prints torch.Size([1, 64, 64]) for img.shape. Why are the two shapes different?
img, digit = next(iter(mnist2))
print(img.shape)
plt.imshow(img.squeeze())
plt.show()
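Here is a minimal pure-Python sketch of the access pattern behind the two results. It assumes torchvision's MNIST follows the usual map-style Dataset convention: raw images are stored in `.data`, and `transform` is applied only inside `__getitem__`, which is what `next(iter(mnist2))` ends up calling. `ToyDataset`, `pad_and_add_channel` and `shape` are hypothetical helpers written for this illustration. They are not torchvision APIs.

```python
class ToyDataset:
    """Hypothetical stand-in following the map-style Dataset convention:
    raw samples live in .data, and the transform runs only in __getitem__."""

    def __init__(self, data, targets, transform=None):
        self.data = data          # raw, untransformed samples
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        if self.transform is not None:
            img = self.transform(img)  # transform is applied here, and only here
        return img, target


def shape(nested):
    """Hypothetical helper: shape of a rectangular nested list, e.g. [[0, 0]] -> (1, 2)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


def pad_and_add_channel(img, pad=18):
    """Hypothetical toy transform: zero-pad each side by `pad`, then add a channel axis."""
    width = len(img[0]) + 2 * pad
    padded = [[0] * width for _ in range(pad)]
    padded += [[0] * pad + row + [0] * pad for row in img]
    padded += [[0] * width for _ in range(pad)]
    return [padded]  # leading channel axis, like ToTensor's C x H x W layout


digit = [[1] * 28 for _ in range(28)]
ds = ToyDataset(data=[digit], targets=[5], transform=pad_and_add_channel)

print(shape(ds.data[0]))   # (28, 28): direct access to .data skips the transform
img, label = next(iter(ds))
print(shape(img), label)   # (1, 64, 64) 5
```

`iter(ds)` works here without an `__iter__` method because Python falls back to calling `__getitem__` with indices 0, 1, 2, and so on. So the first item from `next(iter(ds))` is the same as `ds[0]`, and it goes through the transform.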