This part is essentially the resampling; the methods below extend a torch.utils.data Dataset.
Since I handle medical data, I have to make sure that the test and validation sets are split by patient (called “domain” in the code below).
The images and ground-truth are stored in self.data
Basically, the dataset stores the raw images and annotations, and I generate a coordinate grid over them.
The __getitem__ method then samples patches at these grid positions. Because resampling can change which domain a given index belongs to, I have to update the samplers afterwards.
def _sample_grid(self):
    """Lay a randomly shifted grid of patch origins over every stored image.

    Every entry of ``self.data`` is unpacked as ``(domain, image, label)``.
    For each image, patch origins are placed ``self.stride`` apart along the
    first two image axes, starting from a random offset in
    ``[0, self.stride)`` that is drawn independently per axis and per image.
    An origin is kept only if a patch of edge length ``self.size`` starting
    there lies fully inside the image; this includes the origin that is
    flush with the far border. Images too small to hold a patch at the
    drawn offset contribute no entries.

    Returns:
        tuple: ``(grid, domains)``. ``grid`` is a list of
        ``(data_index, x, y)`` tuples, where ``data_index`` is the position
        of the image in ``self.data`` and ``x`` / ``y`` are the origins
        along the first / second image axis. ``domains`` is a list of the
        same length holding the domain id of the image behind each entry.
    """

    def parcellate(img, domain, index):
        """Return the patch origins of one image plus matching domain ids."""
        extent_x, extent_y = img.shape[:2]
        stride = self.stride
        size = self.size
        off_x, off_y = np.random.randint(0, stride, size=(2,))
        # ``arange`` stops are exclusive, hence ``+ 1`` so the last origin
        # at which a patch still fits (origin + size == extent) is kept.
        # When no origin fits, ``arange`` simply yields an empty axis, so
        # undersized images produce an empty grid.
        xs = np.arange(off_x, extent_x - size + 1, stride)
        ys = np.arange(off_y, extent_y - size + 1, stride)
        XX, YY = np.meshgrid(xs, ys, indexing="ij")
        doms = domain * np.ones((XX.size,), dtype=np.int32)
        indc = index * np.ones((XX.size,), dtype=np.int32)
        return list(zip(indc, XX.flatten(), YY.flatten())), doms

    grid = []
    domains = []
    # The label is not needed to place the grid; only the image shape is.
    for idx, (dom, img, _lbl) in enumerate(self.data):
        gri, doms = parcellate(img, dom, idx)
        grid.extend(gri)
        domains.extend(doms)
    return grid, domains
def resample_grid(self):
    """Draw a fresh sampling grid and store it on the dataset.

    Calls ``self._sample_grid()`` and stores its two results as
    ``self.grid`` and ``self.domains``, replacing any previous values.
    """
    fresh_grid, fresh_domains = self._sample_grid()
    self.grid = fresh_grid
    self.domains = fresh_domains
def __getitem__(self, index):
    """Cut out the patch that grid entry ``index`` refers to.

    The grid entry names an item of ``self.data`` and a patch origin. A
    window of edge length ``self.size`` is sliced from the first two axes
    of both the image and the label; the label patch is cast to ``int32``.

    Args:
        index: position in ``self.grid`` (and ``self.domains``).

    Returns:
        The tuple ``(image_patch, label_patch, data_index, domain)``, or
        the result of ``self.transform`` applied to that tuple when a
        truthy transform is set.
    """
    data_idx, row0, col0 = self.grid[index]
    _, image, label = self.data[data_idx]

    # Both patches share the same spatial window.
    rows = slice(row0, row0 + self.size)
    cols = slice(col0, col0 + self.size)
    image_patch = image[rows, cols, :]
    label_patch = label[rows, cols, :].astype('int32')

    patch = (image_patch, label_patch, data_idx, self.domains[index])
    if not self.transform:
        return patch
    return self.transform(patch)
After resampling, the samplers are generated by calling:
def get_datasamplers(dset, slides, batch_size):
    """Create random samplers and batch samplers for the three data splits.

    A dataset index belongs to a split when its entry in ``dset.domains``
    is one of the slide ids listed for that split. Each index is added to a
    split at most once, even if a slide id is listed repeatedly. Overlap
    between the splits is not checked here.

    Args:
        dset: sized dataset exposing ``domains``, a sequence holding one
            domain (slide) id per dataset index.
        slides: sequence whose first three entries are the collections of
            slide ids for the training, validation and test split.
        batch_size: number of indices per batch of the batch samplers.

    Returns:
        tuple: ``(train_smpl, train_smpl_batch, valid_smpl,
        valid_smpl_batch, test_smpl, test_smpl_batch)`` where each
        ``*_smpl`` is a ``SubsetRandomSampler`` over the indices of its
        split and each ``*_smpl_batch`` is a ``BatchSampler`` wrapping it
        with ``drop_last=False``.
    """
    # Materialise the domain ids once as an array so the membership test is
    # vectorised and does not rely on list-vs-scalar broadcasting.
    domains = np.asarray(dset.domains)
    indices = np.arange(len(dset))

    def split_indices(slide_ids):
        # A boolean mask selects every index at most once, so a slide id
        # that is listed more than once cannot duplicate its samples.
        return indices[np.isin(domains, slide_ids)].tolist()

    train_idc = split_indices(slides[0])
    valid_idc = split_indices(slides[1])
    test_idc = split_indices(slides[2])

    train_smpl = SubsetRandomSampler(train_idc)
    valid_smpl = SubsetRandomSampler(valid_idc)
    test_smpl = SubsetRandomSampler(test_idc)

    train_smpl_batch = BatchSampler(train_smpl, batch_size, False)
    valid_smpl_batch = BatchSampler(valid_smpl, batch_size, False)
    test_smpl_batch = BatchSampler(test_smpl, batch_size, False)

    return (train_smpl, train_smpl_batch,
            valid_smpl, valid_smpl_batch,
            test_smpl, test_smpl_batch)