Hi all, I am trying to plot a channel of a 3D tensor image from my test data. I load my images with a DataLoader, so they are tensors. The images have shape torch.Size([1, 10, 60200]),
and I have reshaped them to torch.Size([1, 10, 200, 301]),
where test batch size = 1, channels = 10, height = 200 and width = 301 (200 × 301 = 60200).
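As a quick sanity check on the shapes, here is a sketch with random data standing in for a real batch. It assumes the flat 60200 axis stores each of the 200 rows as 301 contiguous samples, because `view` keeps the existing element order.

```python
import torch

# Hypothetical stand-in for one batch from the loader: [batch=1, channels=10, samples=60200]
image = torch.randn(1, 10, 60200)

# 200 * 301 == 60200, so the flat axis splits exactly into height x width
height, width = 200, 301
assert height * width == image.shape[-1]

# view keeps the element order, so this assumes each row of 301 samples is contiguous
image = image.view(1, 10, height, width)
print(image.shape)  # torch.Size([1, 10, 200, 301])
```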
I want to plot one channel of this 3D image, so I rearrange it to torch.Size([200, 301, 10]) by moving the channels to the last dimension. However, when I run the plot function I get TypeError: Invalid dimensions for image data.
Any suggestions for resolving this and plotting a single channel of my 3D image are welcome. My data loading code and plot function are below:
import torch
import torch.utils.data as data_utils

test = data_utils.TensorDataset(torch.from_numpy(test_set), torch.from_numpy(label_set))
test_loader = data_utils.DataLoader(test, batch_size=TestBatchSize, shuffle=True)
for i in range(TestBatchSize):
    image, label = next(iter(test_loader))  # shape of image = [1, 10, 60200]
    image = image.view(TestBatchSize, Inchannels, data_dsp_dim[0], data_dsp_dim[1])  # reshaped to [1, 10, 200, 301]
    for k in range(TestBatchSize):
        # permute (not view) actually moves the channel axis: [10, 200, 301] -> [200, 301, 10]
        seis = image[k].permute(1, 2, 0)
        ### I want to plot just one channel out of the 10 channels in the data ###
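On the "moving the channels" step, `view` and `permute` behave differently. A small sketch with a toy tensor (the 2 × 2 × 3 shape is purely illustrative) shows that `view` to the target shape scrambles the channels while `permute` keeps them intact, and that a single channel can be taken straight from [C, H, W] by indexing.

```python
import torch

# Small hypothetical tensor: 2 channels of a 2 x 3 image, laid out as [C, H, W]
x = torch.arange(12).view(2, 2, 3)

# view only reinterprets the same memory with a new shape; it does NOT move axes
viewed = x.view(2, 3, 2)
# permute reorders the axes, so the channel axis really ends up last: [H, W, C]
permuted = x.permute(1, 2, 0)

print(torch.equal(viewed[..., 0], x[0]))    # False: channel 0 is scrambled
print(torch.equal(permuted[..., 0], x[0]))  # True: channel 0 is preserved

# Indexing [C, H, W] directly gives one 2D channel without any permute
channel0 = x[0]
print(channel0.shape)  # torch.Size([2, 3])
```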
Below is my plot function:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

def display_seismic(seis, font2, font3, SavePath):
    fig, ax = plt.subplots(figsize=(4, 4))
    # assumed symmetric colour range; with vmin == vmax == 0.06 every pixel maps to the same grey
    im = ax.imshow(seis, cmap='gray', aspect='auto', vmin=-0.06, vmax=0.06)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, ax=ax, cax=cax).set_label('Amplitude (m/s)')
    plt.tick_params(labelsize=6)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(6)
    plt.savefig(SavePath + 'Seismic', transparent=True)
    plt.close()
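For anyone hitting the same error, a minimal sketch of the distinction, assuming random data in place of the seismic test set and an arbitrarily chosen channel index: `imshow` accepts only (M, N), (M, N, 3) or (M, N, 4) arrays, so a 10-channel last axis is rejected, while a single 2D channel slice plots fine. The figure is saved to memory rather than disk just to keep the sketch self-contained.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend: render without a display
import matplotlib.pyplot as plt
import torch

# Hypothetical stand-in for one sample from the loop, laid out as [channels, height, width]
sample = torch.randn(10, 200, 301) * 0.03

fig, ax = plt.subplots(figsize=(4, 4))

# Passing all 10 channels as [H, W, C] reproduces the TypeError
error = None
try:
    ax.imshow(sample.permute(1, 2, 0).numpy())
except TypeError as err:
    error = err
print(error)

# Selecting one channel gives a 2D (200, 301) array that imshow colour-maps
channel = 3  # hypothetical choice of channel
seis = sample[channel].cpu().numpy()
im = ax.imshow(seis, cmap="gray", aspect="auto", vmin=-0.06, vmax=0.06)
fig.colorbar(im, ax=ax).set_label("Amplitude (m/s)")

buf = io.BytesIO()  # save to memory instead of disk for this sketch
fig.savefig(buf, format="png")
plt.close(fig)
```

The same idea would apply inside the loop above: pass a 2D slice such as `image[k][channel]` (converted with `.numpy()`) to the plot function rather than the full [200, 301, 10] tensor.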