How to plot Nifty image format in pytorch

Hello
i am trying to plot Nifty format in pytorch after passing them to the loader but i am getting wrong image , code and image shown below

class Dataloder_img(data.Dataset):
    def __init__(self,root_dir,seg_dir,transforms ):
        self.root_dir = root_dir
        self.seg_dir = seg_dir
        self.transforms = transforms
        self.files = os.listdir(self.root_dir)
        self.lables = os.listdir(self.seg_dir)
        print(self.files)
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self,idx):
        img_name = self.files[idx]
        label_name = self.lables[idx]
        img = nib.load(os.path.join(self.root_dir,img_name)) #!Image.open(os.path.join(self.root_dir,img_name))
        #change to numpy
        img = np.array(img.dataobj)
        #change to PIL 
        img = Image.fromarray(img.astype('uint8'), 'RGB')
        
        print(img.size)
        
        label = nib.load(os.path.join(self.seg_dir,label_name))#!Image.open(os.path.join(self.seg_dir,label_name))
        #change to numpy
        label = np.array(label.dataobj)
        #change to PIL 
        label = Image.fromarray(label.astype('uint8'), 'RGB')
        
        print(label.size)
        
        if self.transforms:
            img = self.transforms(img)
            label = self.transforms(label)
            return img,label
        else:
            return img, label
full_dataset = Dataloder_img('/``imageTr',
                                     '/``/labelTr',tfms.Compose([tfms.RandomRotation(0),tfms.Resize((256,256)),tfms.ToTensor()
                                                            ]))#
                                   

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
train_loader = data.DataLoader(train_dataset,shuffle=False,batch_size=bs)
val_loader = data.DataLoader(val_dataset,shuffle=False,batch_size=bs)
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
test_img, test_lb = next(iter(full_dataset))
print(test_img[0].shape)
plt.imshow(test_lb[0])

download

it seem you are printing the dummy image in the data loader. But you can load the nifty file with many method i am showing one tham below

def process_key(event):
    fig = event.canvas.figure
    ax = fig.axes[0]
    if event.key == 'a':
        previous_slice(ax)
    elif event.key == 'q':
        next_slice(ax)
    fig.canvas.draw()

def multi_slice_viewer(volume):
    fig, ax = plt.subplots()
    ax.volume = volume
    ax.index =60
    ax.imshow(volume[ax.index],cmap='gray')
    fig.canvas.mpl_connect('key_press_event', process_key)
data_dir='directory'                                             #image directory
img=nib.load(os.path.join(data_dir,'niftyfile_name.nii'))                           #loading the image
img_data=img.get_data()                                                     #accessing image array
#multi_slice_viewer(img_data)
#plt.show()
#plt.imshow(img_data[75])
print(img_data.shape)
plt.imshow(img_data[75], aspect=1)
#for img_data
plt.imshow(img_data[75])


#How to print slice of the 3D numpy array
a=img_data
a[10:,:,50]