"TypeError: Invalid dimensions for image data" when plotting a channel in a 3D tensor image

Hi all, I am trying to plot one channel of a 3D tensor image from my test data. I used a DataLoader to load my images, so they are tensors. The images come out with shape torch.Size([1, 10, 60200]), which I have reshaped to torch.Size([1, 10, 200, 301]), where test batch size = 1, channels = 10, height = 200 and width = 301. I want to plot one channel of this 3D image, so I reshape it to torch.Size([200, 301, 10]) to move the channels to the last dimension. However, when I run it I get TypeError: Invalid dimensions for image data. Any suggestions to resolve this issue and plot a channel of my 3D image are welcome. My plot function and data loading code are below:
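A quick sanity check of the shapes described above, using a random tensor in place of the real test data (only the shapes come from the post). Splitting the flattened last axis of 60200 into 200 x 301 is a valid `view`, because it only subdivides the innermost axis and leaves the channel axis where it is:

```python
import torch

# Random stand-in for one batch from the loader: [batch, channels, height*width]
flat = torch.randn(1, 10, 60200)

height, width = 200, 301
assert height * width == 60200  # the width is 301, not 100

# Splitting the last axis into (height, width) keeps channels on axis 1
image = flat.view(1, 10, height, width)
print(image.shape)  # torch.Size([1, 10, 200, 301])
```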

```python
import torch
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# test_set, label_set, TestBatchSize, Inchannels and data_dsp_dim are defined elsewhere in my script
test = data_utils.TensorDataset(torch.from_numpy(test_set), torch.from_numpy(label_set))
test_loader = data_utils.DataLoader(test, batch_size=TestBatchSize, shuffle=True)
for i in range(TestBatchSize):
    image, label = next(iter(test_loader))  # shape of image = [1, 10, 60200]
    image = image.view(TestBatchSize, Inchannels, data_dsp_dim[0], data_dsp_dim[1])  # reshaped to [1, 10, 200, 301]

    for k in range(TestBatchSize):
        seis = image.view(data_dsp_dim[0], data_dsp_dim[1], Inchannels)  # reshaped to [200, 301, 10] --> moved channel to the last dimension

### I want to plot just one channel out of the 10 channels in the data ###

# Below is my plot function
def display_seismic(seis, font2, font3, SavePath):
    fig, ax = plt.subplots(figsize=(4, 4))
    # vmin == vmax would map every pixel to one gray level; a symmetric range is assumed
    im = ax.imshow(seis, cmap='gray', aspect='auto', vmin=-0.06, vmax=0.06)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, ax=ax, cax=cax).set_label('Amplitude (m/s)')
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        pass  # the rest of the function was cut off in the original post
```

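In the `view` call that is supposed to move the channels to the last dimension, you are not doing what you say you are doing: `view` does not move any axis, it only regroups the same elements in their existing memory order, so the result has the shape [200, 301, 10] but scrambled content. For this, you should use either `permute(1, 2, 0)`, or even better something like torchvision's `ToPILImage` to take care of this for you (it expects a channel-first tensor, but it only handles 1 to 4 channels, so you would still need to select a channel first).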
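To see why `view` is not the same as moving an axis, here is a small sketch on a tensor with the same per-sample shape as in the question (filled with `arange` so the result is deterministic; the values are illustrative). `view` just regroups the elements in memory order, so its last axis is not the channel axis, while `permute(1, 2, 0)` actually reorders the axes:

```python
import torch

# One sample laid out as [channels, height, width], values 0, 1, 2, ... in memory order
chw = torch.arange(10 * 200 * 301).view(10, 200, 301)

viewed = chw.view(200, 301, 10)   # target shape, but the elements are just regrouped
permuted = chw.permute(1, 2, 0)   # [height, width, channels]: channels really moved last

print(viewed.shape, permuted.shape)  # both torch.Size([200, 301, 10])

# permute keeps each channel intact: slice 0 on the last axis is channel 0
print(torch.equal(permuted[:, :, 0], chw[0]))  # True
# view does not: its "channel 0" takes every 10th element of the flat data
print(torch.equal(viewed[:, :, 0], chw[0]))    # False
```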

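As you said, you only want to plot ONE channel, so you need to select it. Right now you are passing all 10 channels to `imshow`, which only accepts a 2D array of shape (M, N) or an RGB/RGBA array of shape (M, N, 3) or (M, N, 4). An array of shape (200, 301, 10) fits none of these, so pyplot does not know what to do with it and raises the "Invalid dimensions for image data" error.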
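A minimal sketch of plotting a single channel, assuming `image` is the [1, 10, 200, 301] batch from the loader (a random stand-in is used here) and picking channel 0 as an arbitrary, hypothetical choice. Indexing the batch and the channel gives a 2D [200, 301] array, which `imshow` accepts directly, so no permute is needed for a single channel. The symmetric color range and the output filename `channel0.png` are illustrative assumptions:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch

image = torch.randn(1, 10, 200, 301)  # stand-in for one batch: [batch, channels, H, W]

channel = 0                             # hypothetical channel to plot
seis = image[0, channel].cpu().numpy()  # 2D array of shape (200, 301)

fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(seis, cmap="gray", aspect="auto", vmin=-0.06, vmax=0.06)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax).set_label("Amplitude (m/s)")
fig.savefig("channel0.png")  # hypothetical output path
plt.close(fig)
```

The same indexing works with the plot function from the question: calling it with `image[0, channel].cpu().numpy()` instead of the 10-channel tensor avoids the error.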