TypeError: Invalid dimensions for image data

Hi there. I have a problem.
I want to display my images so that I can check precision and recall.
However, imshow fails to display them.
I think the shape of my images does not match what imshow expects,
so I would like to fix this.

dataiter = iter(testloader)
for data in dataiter:
    images, labels = data
    print(images.shape)
    imshow(torchvision.utils.make_grid(images, nrow=5))
    images, labels = images.to(device), labels.to(device)

    outputs = net(images)
    _, predicted = torch.max(outputs.data, 1)
    print(predicted)

    print('Correct : ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    print('Predicted : ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
    print(predicted == labels)

torch.Size([4, 3, 255, 255])


TypeError Traceback (most recent call last)
in ()
3 images, labels = data
4 print (images.shape)
----> 5 imshow(torchvision.utils.make_grid(images, nrow=5))
6
7 # images, labels = images.to(device), labels.to(device)
/opt/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
3203 filternorm=filternorm, filterrad=filterrad,
3204 imlim=imlim, resample=resample, url=url, data=data,
-> 3205 **kwargs)
3206 finally:
3207 ax._hold = washold
/opt/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
1853 "the Matplotlib list!)" % (label_namer, func.__name__),
1854 RuntimeWarning, stacklevel=2)
-> 1855 return func(ax, *args, **kwargs)
1856
1857 inner.__doc__ = _add_data_doc(inner.__doc__,
/opt/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
5485 resample=resample, **kwargs)
5486
-> 5487 im.set_data(X)
5488 im.set_alpha(alpha)
5489 if im.get_clip_path() is None:
/opt/anaconda3/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
651 if not (self._A.ndim == 2
652 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
--> 653 raise TypeError("Invalid dimensions for image data")
654
655 if self._A.ndim == 3:
TypeError: Invalid dimensions for image data

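You could try to permute the image tensor so that the channels are stored in dim 2 (the last dimension). make_grid returns a [channels, height, width] tensor, while plt.imshow expects [height, width, channels], which is exactly the shape check in set_data that raises the error in your traceback: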

images = torch.randn(4, 3, 255, 255)
plt.imshow(torchvision.utils.make_grid(images, nrow=5).permute(1, 2, 0))
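To see why the permute helps, here is a small sketch that mirrors the shape condition from matplotlib's set_data in the traceback. It uses NumPy instead of torch (an assumption, so it runs standalone), and the helper names `is_valid_imshow_shape` and `chw_to_hwc` are hypothetical, not matplotlib or torchvision functions.

```python
import numpy as np

def is_valid_imshow_shape(shape):
    # Same condition matplotlib checks in image.set_data (see the traceback):
    # a 2-D grayscale array, or a 3-D array whose LAST axis has 3 (RGB) or 4 (RGBA) channels.
    return len(shape) == 2 or (len(shape) == 3 and shape[-1] in (3, 4))

def chw_to_hwc(img):
    # Reorder (C, H, W) -> (H, W, C); the NumPy equivalent of tensor.permute(1, 2, 0).
    return np.transpose(img, (1, 2, 0))

grid_chw = np.zeros((3, 259, 1030), dtype=np.float32)  # channels-first RGB array
grid_hwc = chw_to_hwc(grid_chw)

print(is_valid_imshow_shape(grid_chw.shape))  # False: last axis is 1030, not 3 or 4
print(is_valid_imshow_shape(grid_hwc.shape))  # True: shape is (259, 1030, 3)
```

The data itself is untouched; only the axis order changes, which is why no other preprocessing is needed to get past the TypeError.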

Thank you for your advice, my code is working now. But I have a new problem.
When I apply your suggestion, the code outputs:

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

and the images look like snow noise.
In addition, here is my model with its channel sizes:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 60 * 60, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), 16 * 60 * 60)  # flatten for the linear layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
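The 16 * 60 * 60 in the view call follows from the 255 x 255 input size. A minimal sketch of that arithmetic, assuming the standard output-size formula for Conv2d and MaxPool2d with no padding, dilation 1 and the default floor rounding (`conv_out` is a hypothetical helper, not a PyTorch function):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size along one spatial axis for Conv2d / MaxPool2d (dilation=1, floor rounding).
    return (size + 2 * padding - kernel) // stride + 1

size = 255                   # input height/width, from torch.Size([4, 3, 255, 255])
size = conv_out(size, 5)     # conv1, kernel 5            -> 251
size = conv_out(size, 2, 2)  # pool, kernel 2, stride 2   -> 125
size = conv_out(size, 5)     # conv2, kernel 5            -> 121
size = conv_out(size, 2, 2)  # pool, kernel 2, stride 2   -> 60
fc1_in = 16 * size * size    # conv2 has 16 output channels
print(size, fc1_in)          # 60 57600
```

So the model shapes are consistent with the input; the model itself does not affect what imshow displays.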

You could avoid the clipping by normalizing the images to the range [0, 1] first:

images -= images.min()
images /= images.max()
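The same min-max scaling as a standalone NumPy sketch. The helper name `minmax_normalize` and the `eps` guard are my additions, not part of the two lines above: this version returns a new array instead of modifying `images` in place, and avoids dividing by zero for a constant image.

```python
import numpy as np

def minmax_normalize(x, eps=1e-12):
    # Shift so the minimum becomes 0, then scale so the maximum becomes 1.
    # eps prevents a division by zero when every value is the same.
    x = x - x.min()
    return x / max(x.max(), eps)

# Placeholder data with the same shape as the batch in this thread.
images = np.random.default_rng(0).standard_normal((4, 3, 255, 255)).astype(np.float32)
scaled = minmax_normalize(images)
print(scaled.min(), scaled.max())  # 0.0 1.0
```

This scales the whole batch with one global min/max; make_grid's normalize=True, used later in this thread, is another option.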

Note that I only created the images with Gaussian noise (torch.randn) as a placeholder.
You should use your own images tensor instead. Or do you see random noise there too?


Hi,

Sorry, data3 is a tensor of size (64, 21, 21, 11) and I want to see the middle slice as a grid. The code below gives me an error. Could you please tell me what is wrong? fake33 has size (64, 21, 21).
The error is: TypeError: Invalid dimensions for image data

fake33=data3[:,:,:,6]
plt.close("all")
plt.figure()
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(vutils.make_grid(fake33.detach().to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.savefig(os.path.join(root_dirDurringTraining13+'Eoch-'+str(epoch)+'Seed='+str(manualSeed))+'fakes.jpg')

I assume you would like to visualize each of the 64 slices as a grayscale image.
If so, this code should work:

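import matplotlib.pyplot as plt
import torch
import torchvision

imgs = torch.randn(64, 21, 21)
imgs = torchvision.utils.make_grid(imgs.unsqueeze(1))  # (64, 1, 21, 21): 64 single-channel images
print(imgs.shape)
> torch.Size([3, 186, 186])

imgs = imgs.permute(1, 2, 0).numpy()  # (186, 186, 3) for imshow
plt.imshow(imgs)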
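For reference, the original code fails because fake33 is 3-D: make_grid then appears to treat it as a single image with 64 channels, so after the transpose the array has shape (21, 21, 64), and imshow only accepts 3 or 4 channels in the last dimension. unsqueeze(1) turns it into 64 single-channel images, which make_grid tiles (with its defaults nrow=8 and padding=2) and repeats to 3 channels. The NumPy sketch below is a simplified, hypothetical re-creation of that layout (`tile_grayscale` is not a torchvision function) showing where 186 = 8 * (21 + 2) + 2 comes from:

```python
import numpy as np

def tile_grayscale(slices, nrow=8, padding=2, pad_value=0.0):
    # Lay out N (H, W) slices in a grid with `nrow` tiles per row and `padding`
    # pixels between and around tiles, similar to make_grid's default layout.
    n, h, w = slices.shape
    ncols = min(nrow, n)
    nrows = -(-n // ncols)  # ceiling division
    grid = np.full((nrows * (h + padding) + padding,
                    ncols * (w + padding) + padding), pad_value, dtype=slices.dtype)
    for i in range(n):
        r, c = divmod(i, ncols)
        top = r * (h + padding) + padding
        left = c * (w + padding) + padding
        grid[top:top + h, left:left + w] = slices[i]
    return grid

fake33 = np.random.default_rng(0).random((64, 21, 21), dtype=np.float32)  # placeholder data
grid = tile_grayscale(fake33)
print(grid.shape)  # (186, 186)
rgb = np.repeat(grid[..., None], 3, axis=-1)  # (186, 186, 3), channels last for imshow
```

A 2-D result like `grid` can also be passed directly to plt.imshow(grid, cmap="gray") without repeating the channel.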