Visualize torchvision data

I can load the SVHN dataset, but how can I plot some of the images?
I tried to plot them with cv2, but the data is a torch.Tensor, so it failed.
Can anyone help?

from torchvision import datasets, transforms
from torch.utils.data import DataLoader


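# batch_size replaces the undefined params.batch_size; 64 is an arbitrary default
def svhn_train_loader(batch_size=64):
    # ToTensor gives [C, H, W] floats in [0, 1]; Normalize maps them to [-1, 1]
    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))])

    svhn_dataset = datasets.SVHN(root='svhn',
                                 split='train',
                                 transform=pre_process,
                                 download=True)

    svhn_data_loader = DataLoader(
        dataset=svhn_dataset,
        batch_size=batch_size,
        shuffle=True)

    return svhn_data_loader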

You can get a NumPy array from the tensor using data.numpy().
Then you can plot this array with OpenCV, matplotlib, etc.
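Calling .numpy() alone is usually not enough for plotting, because each image is still [channels, height, width] and still normalized to [-1, 1]. A minimal sketch of a helper (the name tensor_to_image is hypothetical) that undoes the Normalize step from the question, moves the channels last and converts to uint8, assuming mean and std are both 0.5 as above:

```python
import numpy as np
import torch


def tensor_to_image(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Turn one normalized [C, H, W] tensor into an [H, W, C] uint8 array.

    mean/std default to the values passed to transforms.Normalize in the
    question; change them if your pipeline uses different ones.
    """
    img = img.detach().cpu()
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    img = img * std + mean        # undo Normalize: x = x_norm * std + mean
    img = img.clamp(0.0, 1.0)     # guard against values slightly outside [0, 1]
    img = img.permute(1, 2, 0)    # [C, H, W] -> [H, W, C]
    return (img.numpy() * 255).round().astype(np.uint8)
```

The result can go straight into plt.imshow. OpenCV expects BGR channel order, so convert it with cv2.cvtColor(image, cv2.COLOR_RGB2BGR) before calling cv2.imshow or cv2.imwrite.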

@ptrblck Thanks for the quick reply. I tried data.numpy():

import matplotlib.pyplot as plt

for batch_idx, (inputs, labels) in enumerate(svhn_train_loader()):
    plt.figure()
    plt.imshow(inputs.numpy())  # inputs has shape [batch_size, 3, 32, 32]
    plt.show()

TypeError: Invalid dimensions for image data
How can I fix this error?

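matplotlib expects the channels in the last dimension (dim2 of an [height, width, channels] image), while PyTorch stores images as [channels, height, width].
Also, inputs is a whole batch of shape [batch_size, 3, 32, 32], so select a single image first and permute it before passing it to imshow:
image = inputs[0].permute(1, 2, 0)
plt.imshow(image.numpy())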
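To look at several images at once, the same permute can be applied per image inside a grid of subplots. Because imshow clips float data to [0, 1], the sketch below also undoes Normalize(mean=0.5, std=0.5) first; the helper name show_batch is hypothetical and the mean/std values are assumed to match the question's loader:

```python
import matplotlib.pyplot as plt
import torch


def show_batch(inputs, labels, n=8):
    """Plot the first n images of a [B, 3, H, W] batch normalized with mean=std=0.5."""
    n = min(n, inputs.size(0))
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img = inputs[i].permute(1, 2, 0)      # one image: [3, H, W] -> [H, W, 3]
        img = (img * 0.5 + 0.5).clamp(0, 1)   # undo Normalize(mean=0.5, std=0.5)
        ax.imshow(img.cpu().numpy())
        ax.set_title(str(int(labels[i])))     # show the digit label above each image
        ax.axis("off")
    fig.tight_layout()
    return fig
```

Usage: inside the loop from the question, call fig = show_batch(inputs, labels), then plt.show(), and break after the first batch.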
