Image not displaying properly

I have the following code portion:

import torch
from torch.utils.data import DataLoader
from PIL import Image

dataset = trainDataset()
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

images = []
image_labels = []

for i, data in enumerate(train_loader,0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    inputs, labels = inputs.float(), labels.float()

image = images[7]
image = image.numpy()
image = image.reshape(416,416,3)
img = Image.fromarray(image,'RGB')

The issue is that the image doesn’t display properly. For instance, my dataset contains images of cats and dogs, but the displayed image looks as shown below. Why is that?

[image: the incorrectly displayed output]

Assuming the original image is in the channels-first memory layout, this line of code:

image = image.reshape(416,416,3)

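would only reinterpret the flat buffer without moving any data, so values from the separate channel planes end up interleaved within each pixel. That is what produces the posted image.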
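The effect is easy to see on a tiny array. The sketch below uses NumPy and a made-up 2x2 image (the values 10, 20 and 30 are placeholders for the R, G and B planes, not data from the question):

```python
import numpy as np

# Hypothetical 2x2 image in channels-first (C, H, W) layout:
# the R plane is all 10s, the G plane all 20s, the B plane all 30s.
chw = np.stack([np.full((2, 2), v, dtype=np.uint8) for v in (10, 20, 30)])

# reshape only relabels the axes; the flat memory order is unchanged.
wrong = chw.reshape(2, 2, 3)

print(chw.ravel().tolist())  # [10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30]
print(wrong[0, 0].tolist())  # [10, 10, 10]  three R values, not one (R, G, B) triple
print(wrong[0, 1].tolist())  # [10, 20, 20]
```

Each "pixel" of the reshaped array is just three consecutive values from the flat buffer, so it mixes samples from different positions and channels.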
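To move the channel axis to the end instead, permute the axes: image = image.permute(1, 2, 0). Note that permute is a tensor method, so call it before image.numpy(); on a NumPy array the equivalent is np.transpose(image, (1, 2, 0)). If the tensor still carries the leading batch dimension from the DataLoader (batch_size=1), drop it first, for example with image.squeeze(0).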
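Putting this together, a minimal NumPy sketch of the conversion might look as follows. The helper name to_displayable is made up for illustration, and the scaling step assumes the pixel values are floats in [0, 1], which the question does not state. The conversion to uint8 is included because the question casts the inputs to float, while an RGB image in PIL holds 8-bit data.

```python
import numpy as np

def to_displayable(batch):
    """Convert a (1, 3, H, W) float array to an (H, W, 3) uint8 array."""
    chw = batch[0]                      # drop the batch dimension -> (3, H, W)
    hwc = np.transpose(chw, (1, 2, 0))  # move channels last -> (H, W, 3)
    # ASSUMPTION: values are floats in [0, 1]; skip the scaling if the
    # dataset already yields values in [0, 255].
    return (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)

# Hypothetical stand-in for one batch from the DataLoader.
batch = np.random.rand(1, 3, 416, 416).astype(np.float32)
image = to_displayable(batch)
print(image.shape, image.dtype)  # (416, 416, 3) uint8
```

The result can then be passed to Image.fromarray(image). Starting from a torch tensor, the first two steps would be inputs[0].permute(1, 2, 0).cpu().numpy().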