How can I solve the error: TypeError: Invalid shape (60, 60, 8) for image data

Hi, I'm new to PyTorch. I'm trying to create a DCGAN project and used the whole official PyTorch tutorial (the "DCGAN Tutorial" in the PyTorch Tutorials 1.10.1+cu102 documentation) as a base.

I have a NumPy array built by combining eight arrays, which gives each sample the unusual shape (60, 60, 8):

import numpy as np

# intesity: the source array defined earlier, assumed to have shape (480, 480, 8),
# so that each 60x60 tile keeps all 8 channels
lista2 = [0, 60, 120, 180, 240, 300, 360, 420]
total = []
for i in lista2:          # column offset of the tile
    for j in lista2:      # row offset of the tile
        total.append(intesity[j:j+60, i:i+60])

# float32 avoids errors from in-place division on integer arrays
# and matches the default dtype of PyTorch layers
total = np.reshape(total, (64, 60, 60, 8)).astype(np.float32)
total -= total.min()      # shift so the minimum is 0
total /= total.max()      # scale so the maximum is 1
print(np.shape(total))
(64, 60, 60, 8)
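A self-contained sketch of the same tiling step, assuming the source array has shape (480, 480, 8). Random values stand in for intesity, and the names `source`, `offsets` and `tiles` are hypothetical. np.stack makes the (64, 60, 60, 8) layout explicit:

```python
import numpy as np

# Hypothetical stand-in for `intesity`: a (480, 480, 8) array of random values.
rng = np.random.default_rng(0)
source = rng.random((480, 480, 8)) * 255

offsets = range(0, 480, 60)  # same offsets as lista2: 0, 60, ..., 420
# Column offset outer, row offset inner, matching the original loop order.
tiles = [source[r:r + 60, c:c + 60] for c in offsets for r in offsets]

total = np.stack(tiles).astype(np.float32)  # shape (64, 60, 60, 8)
total -= total.min()  # minimum becomes 0
total /= total.max()  # maximum becomes 1

print(total.shape, total.dtype)
```

Tile k of the 64 is the (k % 8)-th row block of the (k // 8)-th column block, which is the order the original nested slicing produces.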

As you can see, the array holds 64 elements, i.e. 64 training images (very few for now). I convert this array to a tensor and then to a PyTorch dataset:

tensor_c = torch.tensor(total)

After creating a dataset and a DataLoader, I get the following error when I try to plot the training images of this DCGAN:

from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(tensor_c)  # create your dataset
dataloader = DataLoader(dataset)   # create your dataloader (default batch_size=1)
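DataLoader defaults to batch_size=1, so real_batch[0] holds a single image of shape (1, 60, 60, 8) rather than 64 images. The sketch below uses random data as a stand-in for tensor_c and assumes batch_size=64 to match the 64 training images. It prints the batch shapes and permutes them to channel-first (B, C, H, W), the layout that torchvision.utils.make_grid and nn.Conv2d expect:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for tensor_c: 64 images stored channel-last (N, H, W, C).
tensor_c = torch.rand(64, 60, 60, 8)
dataset = TensorDataset(tensor_c)

# Default batch_size=1: each batch holds a single image.
single = next(iter(DataLoader(dataset)))[0]
print(single.shape)  # torch.Size([1, 60, 60, 8])

# batch_size=64 returns all 64 training images at once.
loader = DataLoader(dataset, batch_size=64, shuffle=True)
batch = next(iter(loader))[0]

# Reorder to channel-first (B, C, H, W) for make_grid and Conv2d.
batch_chw = batch.permute(0, 3, 1, 2)
print(batch_chw.shape)  # torch.Size([64, 8, 60, 60])
```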

real_batch = next(iter(dataloader))
plt.figure(figsize=(16,16))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=0, normalize=True).cpu(),(1,2,0)))
dataset_size = len(dataloader.dataset)
dataset_size
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-5ba2d666ef25> in <module>()
     10 plt.axis("off")
     11 plt.title("Training Images")
---> 12 plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=0, normalize=True).cpu(),(1,2,0)))
     13 dataset_size = len(dataloader.dataset)
     14 dataset_size

/usr/local/lib/python3.7/dist-packages/matplotlib/image.py in set_data(self, A)
    697                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    698             raise TypeError("Invalid shape {} for image data"
--> 699                             .format(self._A.shape))
    700 
    701         if self._A.ndim == 3:

TypeError: Invalid shape (60, 60, 8) for image data
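A minimal reproduction of the check in the last frame, using random arrays. The Agg backend is an assumption so that it runs without a display. A 2-D array, or a 3-D array ending in 3 or 4 channels, is drawn, while the 8-channel array raises the same TypeError:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; no window needed
import matplotlib.pyplot as plt
import numpy as np

shapes = [(60, 60), (60, 60, 3), (60, 60, 4), (60, 60, 8)]
results = {}
for shape in shapes:
    fig, ax = plt.subplots()
    try:
        ax.imshow(np.random.rand(*shape))
        results[shape] = "ok"
    except TypeError as err:  # raised for the 8-channel array
        results[shape] = str(err)
    plt.close(fig)

for shape, outcome in results.items():
    print(shape, "->", outcome)
```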

I'm quite new to PyTorch, so I would like to know how I can solve this problem.

You are trying to display the image, but plt.imshow only accepts a 2-D array (grayscale) or a 3-D array whose last dimension is 3 (RGB) or 4 (RGBA); that is the check failing in the last frame of your traceback. Your output has 8 channels (and I am not sure I understand what they represent), and matplotlib has no rule for mapping 8 channels to a color. Perhaps you can reduce them to a single channel yourself by averaging the 8 values at each pixel, with something like

mean_image = torch.mean(eight_channel_image, dim=-1)

# where eight_channel_image is the (60, 60, 8) array you currently pass to plt.imshow(...), converted to a torch.Tensor
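A runnable version of that suggestion, with a random (60, 60, 8) tensor as a hypothetical stand-in for eight_channel_image. Averaging over the last dimension leaves a (60, 60) array, which imshow draws with a colormap. cmap="gray" is a choice here, not a requirement:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import torch

# Hypothetical stand-in for the (H, W, 8) image passed to plt.imshow.
eight_channel_image = torch.rand(60, 60, 8)

# Average the 8 channels into one grayscale value per pixel.
mean_image = torch.mean(eight_channel_image, dim=-1)  # shape (60, 60)

fig, ax = plt.subplots()
ax.imshow(mean_image.numpy(), cmap="gray")
ax.axis("off")
plt.close(fig)
```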

What exactly are you expecting to see when plotting this?

Hi, thanks for answering. The array has 8 channels because it is the combination of 8 images (arrays).

I expect to get something similar to this:

[image: expected output]

plt.imshow accepts a 2-D array, or a 3-D array whose last dimension is 3 or 4, so 8 channels will not work. Since you say the array combines 8 images (one per channel), I assume each image is black and white (1 channel). You can go from 8 channels to 1 by averaging the values across the channels (which is what np.mean or torch.mean do) and then passing the result to imshow.