for medical images format Nifty ( nii.gz) how to load the image using data loader in pytorch and then plot them, my code blow try to do but seems there is something wrong with the numby shape since the image show as lines instead of normal medical images
bs = 2
num_epochs = 100
learning_rate = 1e-3
mom = 0.9
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchvision
import torchvision.transforms as tfms
import torch.utils.data as data
import matplotlib.pyplot as plt
import torch.nn.functional as F
import scipy as sc
import os
import PIL
import PIL.Image as Image
import seaborn as sns
import warnings
import nibabel as nib#http://nipy.org/nibabel/gettingstarted.html
class Dataloder_img(data.Dataset):
def __init__(self,root_dir,seg_dir,transforms ):
self.root_dir = root_dir
self.seg_dir = seg_dir
self.transforms = transforms
self.files = os.listdir(self.root_dir)
self.lables = os.listdir(self.seg_dir)
print(self.files)
def __len__(self):
return len(self.files)
def __getitem__(self,idx):
img_name = self.files[idx]
label_name = self.lables[idx]
img = nib.load(os.path.join(self.root_dir,img_name)) #!Image.open(os.path.join(self.root_dir,img_name))
#change to numpy
img = np.array(img.dataobj)
#change to PIL
img = Image.fromarray(img.astype('uint8'), 'RGB')
print(img.size)
label = nib.load(os.path.join(self.seg_dir,label_name))#!Image.open(os.path.join(self.seg_dir,label_name))
#change to numpy
label = np.array(label.dataobj)
#change to PIL
label = Image.fromarray(label.astype('uint8'), 'RGB')
print(label.size)
if self.transforms:
img = self.transforms(img)
label = self.transforms(label)
return img,label
else:
return img, label
full_dataset = Dataloder_img(' image ',
' labels ',tfms.Compose([tfms.RandomRotation(180),tfms.ToTensor()
]))#
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
train_loader = data.DataLoader(train_dataset,shuffle=False,batch_size=bs)
val_loader = data.DataLoader(val_dataset,shuffle=False,batch_size=bs)
test_img, test_lb = next(iter(full_dataset))
print(test_img[0].shape)
plt.imshow(test_img[0])
plt.show()