Hello, I am trying to convert the MNIST images to RGB using torchvision.transforms. My code is as follows:
```python
import torch
from torchvision import datasets, transforms

IMG_SIZE = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 256

TRANSFORM = transforms.Compose([
    #transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),                              # PIL image -> float tensor (1, H, W) in [0, 1]
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),     # copy the gray channel -> (3, H, W)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # rescale to [-1, 1]
])

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=TRANSFORM)
usps_train = datasets.USPS(root='./data', train=True, download=True, transform=TRANSFORM)
mnist_loader = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
usps_loader = torch.utils.data.DataLoader(usps_train, batch_size=BATCH_SIZE, shuffle=True)

mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=TRANSFORM)
usps_test = datasets.USPS(root='./data', train=False, download=True, transform=TRANSFORM)
mnist_test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=BATCH_SIZE)
usps_test_loader = torch.utils.data.DataLoader(usps_test, batch_size=BATCH_SIZE)

# First batch of images (index 0 of the (images, labels) pair) from each test set
X_test_fixed = next(iter(mnist_test_loader))[0]
Y_test_fixed = next(iter(usps_test_loader))[0]
```
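To make the tensor layout produced by this pipeline concrete, here is a rough NumPy emulation of the `ToTensor` -> `repeat(3, 1, 1)` -> `Normalize` steps. It is a sketch, not torchvision's own code: the random 28x28 array is a hypothetical stand-in for one MNIST digit, `emulate_pipeline` is a hypothetical helper, and `Resize` is skipped because it only changes the height and width. The point it illustrates is that each image ends up channels-first, `(3, H, W)`, with values in [-1, 1] and three identical channels. Likewise, `X_test_fixed` is a batch collated by the DataLoader, so its shape should be `(256, 3, 32, 32)`.

```python
import numpy as np

# Hypothetical stand-in for one 28x28 MNIST digit (uint8, values 0..255).
rng = np.random.default_rng(0)
gray = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

def emulate_pipeline(img_u8):
    """Rough NumPy emulation of ToTensor -> repeat(3, 1, 1) -> Normalize(0.5, 0.5).

    Resize is skipped here; it only changes H and W.
    """
    x = img_u8.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = x[np.newaxis, :, :]                 # ToTensor: add channel axis -> (1, H, W)
    x = np.tile(x, (3, 1, 1))               # repeat(3, 1, 1): copy the gray channel -> (3, H, W)
    x = (x - 0.5) / 0.5                     # Normalize(mean=0.5, std=0.5) per channel -> [-1, 1]
    return x

out = emulate_pipeline(gray)
print(out.shape)                              # (3, 28, 28): channels first
print(out.min() >= -1.0, out.max() <= 1.0)    # True True
print(np.array_equal(out[0], out[2]))         # True: all three channels are identical
```

Because the three channels are copies, the "RGB" image will still look gray once it is displayed correctly.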
The code runs without errors and produces 3-channel images, but when I try to visualize them the results look strange.
To visualize an image I used:
```python
import matplotlib.pyplot as plt

plt.imshow(X_test_fixed[1].reshape(32,32,3))
plt.show()
```
I would appreciate any help. I also tried PIL's `convert('RGB')` (the commented-out `Lambda` in the transform), and the result is the same.
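A likely cause, offered as a suggestion rather than a confirmed diagnosis: `reshape(32, 32, 3)` does not move the channel axis of a `(3, 32, 32)` tensor; it only reinterprets the same flat memory, so each displayed "pixel" takes three consecutive values from a single channel. `plt.imshow` expects channels last, which needs a transpose (`permute(1, 2, 0)` in PyTorch), and float RGB data in [0, 1], so the `Normalize` step should also be undone. The NumPy sketch below shows the difference on a hypothetical 3x4x4 image whose channels hold 0, 1 and 2; `to_displayable` is a hypothetical helper that assumes `Normalize` was applied with mean 0.5 and std 0.5, as in the transform above.

```python
import numpy as np

# Hypothetical 3x4x4 CHW image whose channels are easy to tell apart:
# channel 0 is all 0s, channel 1 all 1s, channel 2 all 2s.
chw = np.stack([np.full((4, 4), c, dtype=np.float32) for c in range(3)])
print(chw.shape)  # (3, 4, 4)

# reshape only reinterprets the flat memory order; it does NOT move the
# channel axis, so each "pixel" gets 3 consecutive values from one channel.
wrong = chw.reshape(4, 4, 3)

# transpose (torch: .permute(1, 2, 0)) actually moves channels last (HWC).
right = chw.transpose(1, 2, 0)

print(right[0, 0])  # [0. 1. 2.]: one value from each channel
print(wrong[0, 0])  # [0. 0. 0.]: three values from channel 0

def to_displayable(chw_normalized):
    """Undo Normalize(0.5, 0.5) and convert CHW -> HWC for plt.imshow."""
    hwc = chw_normalized.transpose(1, 2, 0)          # channels last
    return np.clip(hwc * 0.5 + 0.5, 0.0, 1.0)        # [-1, 1] -> [0, 1]
```

With the tensors from the question, the equivalent PyTorch call would be `plt.imshow((X_test_fixed[1].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy())`, which works here because `X_test_fixed` was never moved to the GPU.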