You will need a custom Dataset.
For my use case (I assume you use nibabel to handle NIfTI files), I preload all the file handles with nibabel. This is cheap, because `nib.load` only reads the header and leaves the voxel data on disk behind a proxy.
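During iteration you can then read the data with `np.asarray(nii_image.dataobj)`, which does not keep a cached copy. If you use `get_data()` (deprecated since nibabel 3.0; its replacement `get_fdata()` takes the same `caching` argument), you should set the caching explicitly: by default it keeps every loaded image in a cache, which can exhaust your system memory once too many images have been loaded.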
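```python
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    def __init__(self, paths, targets):
        # load all NIfTI handles in a list (header only, data stays on disk)
        self.images_list = [nib.load(image_path) for image_path in paths]
        # targets[i] is the label for paths[i]; how you build it depends on your data
        self.targets = targets

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        nii_image = self.images_list[idx]
        # read the voxels through the proxy so nibabel keeps no cached copy
        data = torch.from_numpy(np.asarray(nii_image.dataobj))
        target = self.targets[idx]
        return data, target
```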
You will need a way to associate each sample with its target in your `__getitem__` function; this part depends on how you organize your data, directories, etc.
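One common approach, sketched here under an assumed (hypothetical) layout where each class has its own subfolder, `root/<class_name>/<scan>.nii(.gz)`, is to build a list of file paths and a parallel list of integer labels up front, then look the label up by index in `__getitem__`. The function name `collect_paths_and_targets` is illustrative, not part of any library:

```python
from pathlib import Path


def collect_paths_and_targets(root):
    """Hypothetical layout: root/<class_name>/<scan>.nii or .nii.gz.

    Returns parallel lists of file paths and integer labels, plus the
    class-name -> label mapping.
    """
    root = Path(root)
    # sort class folders so labels are stable across runs
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(class_names)}
    paths, targets = [], []
    for name in class_names:
        for f in sorted((root / name).iterdir()):
            # keep only NIfTI files, skip anything else in the folder
            if f.name.endswith((".nii", ".nii.gz")):
                paths.append(str(f))
                targets.append(class_to_idx[name])
    return paths, targets, class_to_idx
```

If your labels live in a spreadsheet or CSV instead, the same idea applies: produce one label per path, in the same order, and index into it.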