ESRGAN Implementation: RuntimeError: number of dims don't match in permute

I was working on an ESRGAN implementation, and while running the code, the following error occurred:

RuntimeError: number of dims don't match in permute

The code is given below:

```python
import torch
from torchvision import transforms
from PIL import Image
from esrgan_model import ESRGAN
print("imports done")

# load the pretrained ESRGAN model
model = ESRGAN()
print('model imported')

# define the image preprocessing steps
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
print("preprocessing steps defined")

# load the image
img = Image.open("example.jpg")
print("loading the image")

# preprocess the image and add a batch dimension
img = transform(img).unsqueeze(0)
print('Preprocessing the image...')

# run the image through the model
output = model(img)
print('Running the image through the model...')

# save the output image
print(output.shape)
output = output.squeeze()
print(output.shape)
output = output.permute(1, 2, 0)
output = (output * 0.5) + 0.5  # undo Normalize((0.5, ...), (0.5, ...))
output = transforms.ToPILImage()(output)
output.save("output.jpg")
```
The shapes printed before and after `squeeze()` are:

torch.Size([1, 1, 28, 28])
torch.Size([28, 28])

If I change the dims, I get the following error:

IndexError: Dimension out of range (expected to be in range of [-2, 1],

How can I resolve this error?

Calling `squeeze()` without an argument removes every dimension of size 1, so both dim0 and dim1 are dropped and the tensor becomes `[28, 28]`. The following `permute(1, 2, 0)` call expects 3 dimensions and thus fails.
If you want to squeeze dim1 only, use `output = output.squeeze(1)` instead.
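A minimal sketch of the difference between `squeeze()` and `squeeze(1)`, assuming a random stand-in tensor with the `[1, 1, 28, 28]` shape reported in the question in place of the real ESRGAN output:

```python
import torch

# Hypothetical stand-in for the model output (same shape as reported above)
output = torch.rand(1, 1, 28, 28)

print(output.squeeze().shape)   # torch.Size([28, 28]): every size-1 dim is dropped
print(output.squeeze(1).shape)  # torch.Size([1, 28, 28]): only dim 1 is dropped

# permute needs exactly one index per dimension, so a 3-index permute fails on a 2D tensor
try:
    output.squeeze().permute(1, 2, 0)
except RuntimeError as err:
    print("permute failed:", err)

# With squeeze(1) the tensor is 3D, so permute(1, 2, 0) works
print(output.squeeze(1).permute(1, 2, 0).shape)  # torch.Size([28, 28, 1])
```

The exact wording of the `permute` error message differs between PyTorch versions, which is why the sketch only prints it.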

When I do so, I get the following error:

```
Traceback (most recent call last):
  File "D:/pythonProject/main.py", line 34, in <module>
    output = transforms.ToPILImage()(output)
  File "D:\Venvs\lib\site-packages\torchvision\transforms\transforms.py", line 179, in __call__
    return F.to_pil_image(pic, self.mode)
  File "D:\Venvs\lib\site-packages\torchvision\transforms\functional.py", line 227, in to_pil_image
    raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3]))
ValueError: pic should not have > 4 channels. Got 28 channels.
```

The tensor now has 3 dimensions, as expected:

torch.Size([1, 28, 28])

but the error above still occurs.

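For tensor inputs, `ToPILImage` expects either a 2D `HxW` tensor or a 3D tensor in `CxHxW` layout. After `permute(1, 2, 0)` the tensor has shape `[28, 28, 1]`, so its first dimension is read as 28 channels, which triggers the error. You can either keep squeezing both dimensions with `squeeze()` and remove the `permute` call, or keep `squeeze(1)` and squeeze the trailing channel dimension after the `permute` call: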

```python
output = transforms.ToPILImage()(output.squeeze(2))
```