View the images by slicing an NCHW batch one image at a time

I'm trying to view my images after loading them with a DataLoader.

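import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# df (with 'file_path' and 'class_name' columns) and train_transform are defined elsewhere
class Classification(Dataset):
    def __init__(self, df, length, transform=None):
        self.df = df
        self.data_len = len(self.df)
        self.len = length              # number of samples to yield; may exceed len(df)
        self.transform = transform

    def __getitem__(self, index):
        data_idx = index % self.data_len    # wrap around when length > len(df)
        # convert('RGB') makes grayscale/RGBA files 3-channel like the rest
        X = Image.open(self.df['file_path'].iloc[data_idx]).convert('RGB')
        y = torch.tensor(int(self.df['class_name'].iloc[data_idx]))
        if self.transform:
            X = self.transform(X)
        return X, y

    def __len__(self):
        return self.len
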
length=2
training_set = Classification(df, length, transform=train_transform)
train_loader = DataLoader(training_set, batch_size=5)

for batch_idx, (inputs, labels) in enumerate(train_loader):   
    print(inputs.shape)     #torch.Size([2, 3, 224, 224])
    inputs = np.squeeze(inputs, axis=0) 
    inputs = inputs.permute(1, 2, 0)
    plt.figure()
    plt.imshow(inputs.numpy())
    plt.show()

The code above fails because np.squeeze cannot remove dimension 0: its size is 2 (the batch holds two images), not 1. How can I slice out one image at a time and plot it in a for loop?
I want to check all images. If length=25, then I would like to view 25 images.
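To see why the printed batch holds 2 images rather than 5, and what length=25 would give, here is a pure-Python sketch of which df rows the loader reads. It assumes the DataLoader settings used above (no shuffle, default drop_last=False); the helper batch_row_indices and the data_len values are hypothetical, chosen only for illustration.

```python
def batch_row_indices(length, data_len, batch_size):
    """Return, per batch, the df rows a sequential DataLoader would read (hypothetical helper)."""
    # Classification.__getitem__ maps dataset index i to df row i % data_len
    rows = [i % data_len for i in range(length)]
    # With no shuffling and drop_last=False, batches are consecutive chunks
    # of batch_size items; the last chunk may be smaller.
    return [rows[start:start + batch_size] for start in range(0, length, batch_size)]


print(batch_row_indices(2, 10, 5))   # [[0, 1]]: one batch of 2 images
print(batch_row_indices(7, 3, 5))    # [[0, 1, 2, 0, 1], [2, 0]]: rows repeat
```

So with length=25 and batch_size=5 the loop sees five batches of five images, 25 in total; if df has fewer than 25 rows, some images are shown more than once.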

Instead of squeeze, which only removes dimensions of size 1, you could index the images in the batch one at a time in a loop:

for batch_idx, (inputs, labels) in enumerate(train_loader):
    print(inputs.shape)     # torch.Size([2, 3, 224, 224])
    for idx in range(inputs.size(0)):    # loop over the N images in this batch
        img = inputs[idx]                # [C, H, W]
        img = img.permute(1, 2, 0)       # [H, W, C], the layout plt.imshow expects
        plt.figure()
        plt.imshow(img.numpy())
        plt.show()
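If train_transform ends with transforms.Normalize, the tensors hold values outside [0, 1] and plt.imshow will clip them, so the images can look washed out or dark. A minimal sketch of undoing the normalization before plotting, assuming per-channel mean/std normalization; the helper to_displayable is hypothetical, and the MEAN/STD values are only the commonly used ImageNet statistics, to be replaced by whatever train_transform actually uses.

```python
import numpy as np

# Assumed values (common ImageNet statistics); use the ones from train_transform.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def to_displayable(chw, mean=MEAN, std=STD):
    """Turn one normalized [C, H, W] image into an [H, W, C] float array in [0, 1]."""
    img = np.asarray(chw, dtype=np.float32)   # also accepts a CPU torch tensor
    img = np.transpose(img, (1, 2, 0))        # same reordering as tensor.permute(1, 2, 0)
    # Undo (x - mean) / std channel by channel (broadcasts over H and W)
    img = img * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
    return np.clip(img, 0.0, 1.0)             # imshow expects floats in [0, 1]
```

Inside the loop above, plt.imshow(to_displayable(inputs[idx])) would then replace the permute and .numpy() lines.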