I am writing a Dataloader for my own data. I return the output as numpy, but the DataLoader gives me torch.Tensor as the output. I don't understand why.
from torch.utils import data
import torch
import numpy as np
import nibabel as nib

class getdata(data.Dataset):
    '''
    Initializes a dataset for the network.
    Assumes that data_dir has files named MRimages and CTimages that contain all the images
    for all the patients in .hdr format.
    '''
    def __init__(self, data_dir, transform):
        'Initialization'
        self.data_dir = data_dir
        self.transform = transform
        self.list_IDs = np.arange(nib.load(self.data_dir + '/MRimages.img').shape[2])  # list of all patients.

    def __len__(self):
        'Total no. of samples. Make sure that the numbers of MR and CT samples are the same.'
        num = len(self.list_IDs)  # total number of slices.
        return num

    def __getitem__(self, index):  # index is patient ID
        'Generate one sample of data'
        ID = self.list_IDs[index]
        MR = np.asarray(nib.load(self.data_dir + '/MRimages.img').get_data()[:, :, ID])
        CT = np.asarray(nib.load(self.data_dir + '/CTimages.img').get_data()[:, :, ID])
        sample = {'MR': MR, 'CT': CT}
        if self.transform:
            sample = self.transform(sample)
        return sample
Calling function:
data_dir = '...'  # my data directory
a = np.arange(nib.load(data_dir + '/MRimages.img').shape[2])

params = {'batch_size': 64,
          'shuffle': True,
          'num_workers': 2}

train_set = getdata(data_dir, transform=None)
train_gen = data.DataLoader(train_set, **params)

for s in train_gen:
    print(type(s['MR']))
This gives me <class ‘torch.Tensor’> for every batch.
I want to make self.transform a class that works on numpy arrays and not on torch tensors.
Any suggestions?
The reason your DataLoader returns torch.Tensors even though you are returning numpy arrays is most likely the default_collate method. You can see in the line of code I'm referring to how numpy arrays are wrapped in torch.Tensors.
If you check the type of train_set[0] you should get a numpy array, which means that the transform in __getitem__ is actually working on numpy arrays. The DataLoader just makes your life a bit easier as you probably want to use torch.tensors in your training loop.
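If you really do want numpy arrays out of the DataLoader, one option is to pass a custom collate_fn so that default_collate is never used. A minimal sketch (untested against your data; numpy_collate is just a hypothetical helper name):

import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # batch is a list of the sample dicts returned by __getitem__;
    # stack each key into one numpy array instead of a torch.Tensor
    return {key: np.stack([sample[key] for sample in batch]) for key in batch[0]}

loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=numpy_collate)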
That makes sense. So my numpy arrays are contained in a torch batch, and I can apply my transform on numpy arrays.
Another question. I am assuming that since my batch size is 64, when __getitem__ is called, the index should contain 64 random indices that are used to get the images.
But when I do MR.transpose((2, 1, 0)), I get ValueError: axes don't match array.
MR.transpose((1, 0)) works fine, so it seems that __getitem__ is getting one ID at a time.
I want to apply my transpose to 3D arrays and not to 2D arrays.
Am I doing something wrong? Also, my PyTorch version is 0.4.1.post2.
The __getitem__ method uses an index to get a single sample, not a batch, i.e. the batch dimension of your data is missing in __getitem__.
Usually this makes developing a custom Dataset really easy, as you just have to think about how to get a single sample of data. The DataLoader yields a complete batch of samples and provides some additional functionality like shuffling the dataset or using multiple workers.
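As a rough illustration with the getdata setup from above (the shapes are just placeholders):

sample = train_set[0]          # __getitem__ -> one slice, e.g. MR of shape (200, 200)
print(sample['MR'].shape)

batch = next(iter(train_gen))  # the DataLoader stacks batch_size samples -> (64, 200, 200)
print(batch['MR'].shape)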
It seems in your use case you would like to load a whole bunch of slices of your MRI images.
Could you print the shapes of MR and CT in __getitem__ since I’m currently not sure how your indexing works.
I assume the shape you’ve printed is the shape of the batch from the DataLoader.
If that’s the case your MR and CT data have the shape [200, 200] in __getitem__.
Based on your description it seems you would like to apply the same transformation on the whole batch instead of each single image.
If that's the case, one approach would be to create your own sampler and provide a list of random indices to your Dataset.
In __getitem__ you would then get a list of indices of length batch_size, load the images one by one, and apply the same transformation to all of them.
In the training loop you would get an additional batch dimension and can just squeeze it.
Here is a small example:
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # index is a list of indices provided by the sampler,
        # so one "sample" is already a whole batch of images
        x = torch.stack([self.data[i] for i in index])
        return x

    def __len__(self):
        return len(self.data)


class RandomBatchSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        # shuffle all indices and chunk them into lists of batch_size
        rand_idx = torch.randperm(len(self.data_source)).tolist()
        data_iter = iter([rand_idx[i:i + self.batch_size]
                          for i in range(0, len(rand_idx), self.batch_size)])
        return data_iter

    def __len__(self):
        return len(self.data_source) // self.batch_size
data = torch.randn(100, 3, 24, 24)
dataset = MyDataset(data)

batch_size = 64
sampler = RandomBatchSampler(data, batch_size=batch_size)

loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=2,
    sampler=sampler
)

for x in loader:
    x.squeeze_(0)
    print(x.shape)
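Applied to your getdata class, this could look roughly like the following (untested sketch; getdata3d is just a hypothetical name, and it keeps your nibabel loading as-is):

class getdata3d(data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.MR = nib.load(data_dir + '/MRimages.img').get_data()
        self.CT = nib.load(data_dir + '/CTimages.img').get_data()

    def __len__(self):
        return self.MR.shape[2]

    def __getitem__(self, indices):  # indices is a list of length batch_size from the sampler
        MR = np.stack([np.asarray(self.MR[:, :, i]) for i in indices])  # shape (batch, H, W)
        CT = np.stack([np.asarray(self.CT[:, :, i]) for i in indices])
        sample = {'MR': MR, 'CT': CT}
        if self.transform:
            sample = self.transform(sample)  # transform now sees 3D numpy arrays
        return sample

train_set = getdata3d(data_dir, transform=None)
sampler = RandomBatchSampler(train_set, batch_size=64)
train_gen = data.DataLoader(train_set, batch_size=1, sampler=sampler, num_workers=2)

for s in train_gen:
    MR = s['MR'].squeeze_(0)  # drop the extra batch dimension added by the DataLoader
    print(MR.shape)           # -> torch.Size([64, H, W])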