Writing a function in PyTorch that reads all the images in a folder

I want to use PyTorch to write a simple function that reads the images in a folder and then visualizes them. This is the code I used:

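```python
import os

import matplotlib.pyplot as plt
import torch
import torchvision

image_path = "path/to/images"  # placeholder: folder that holds the images

filename = [name for name in os.listdir(image_path)]
batch_size = len(filename)
# one uint8 slot per image: (N, channels, height, width)
batch = torch.zeros(batch_size, 3, 240, 320, dtype=torch.uint8)
for i, file in enumerate(filename):
    batch[i] = torchvision.io.read_image(os.path.join(image_path, file))
    fig = plt.figure(figsize=(8, 2))
    # plotting loop, nested inside the file loop and reusing the name i
    for i in range(batch.shape[0]):
        ax = fig.add_subplot(1, 3, i + 1)
        # read_image returns (C, H, W); imshow expects (H, W, C)
        ax.imshow(batch[i].permute(1, 2, 0))
```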

I got the error "ValueError: num must be an integer with 1 <= num <= 3, not 4".
I suppose it's because it only shows 3 images.
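For context, `fig.add_subplot(nrows, ncols, index)` only accepts an `index` from 1 to `nrows * ncols`, so a 1 × 3 grid fails as soon as the loop asks for the fourth subplot. A minimal sketch of sizing the grid from the number of images instead; `grid_shape` is a hypothetical helper, not part of matplotlib:

```python
import math


def grid_shape(num_images, ncols=3):
    """Hypothetical helper: return (nrows, ncols) with nrows * ncols >= num_images.

    add_subplot(nrows, ncols, index) accepts 1 <= index <= nrows * ncols,
    so the number of rows has to grow with the number of images.
    """
    nrows = math.ceil(num_images / ncols)
    return nrows, ncols


print(grid_shape(252))  # → (84, 3)
print(grid_shape(4))    # → (2, 3)
```

The returned pair would then be used as `fig.add_subplot(nrows, ncols, i + 1)` inside the plotting loop.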
But isn't batch[0] supposed to have the value of batch_size, which in my case is 252, since I have 252 RGB images of 320×240?
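For reference, indexing a tensor and reading its shape are different things: `batch[0]` returns the first image, while `batch.shape[0]` (or `len(batch)`) gives the number of images. A minimal sketch with a hypothetical batch of 4 images instead of 252:

```python
import torch

# Hypothetical stand-in for the real batch: 4 RGB images of 240 x 320 pixels
batch = torch.zeros(4, 3, 240, 320, dtype=torch.uint8)

first_image = batch[0]    # indexing returns one image tensor
print(first_image.shape)  # → torch.Size([3, 240, 320])
print(batch.shape[0])     # → 4, the number of images in the batch
print(len(batch))         # → 4, len() of a tensor is the size of its first dimension
```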
Also, when I try to print i, I only get 0. Why is that? Can anyone help me?