Hey @ptrblck,
I wrote a `Dataset` like the one below:
class SegSet(data.Dataset):
    """Dataset yielding 3D image patches and one-hot segmentation patches.

    Each item corresponds to ONE subject: the subject's volume is cut into
    patches by ``generate_patch_32_3``, so ``__getitem__`` returns the whole
    patch stack at once — e.g. arrays shaped (192, 1, 16, 64, 64) and
    (192, num_labels, 8, 32, 32) — plus the subject name.
    """
    def __init__(self, subdict, num_labels):
        """
        :param subdict: a dictionary of subject and label paths
        :param num_labels: number of segmentation labels (9)
        """
        self.subdict = subdict
        self.img_subs = subdict['img_subs']
        self.img_files = subdict['img_files']
        # checkKey tests whether segmentation labels are available for this split
        if checkKey(subdict, 'seg_subs'):
            self.seg_subs = subdict['seg_subs']
            self.seg_files = subdict['seg_files']
        else:
            # No ground-truth segmentation (e.g. an inference-only split);
            # NOTE(review): __getitem__ still reads self.seg_files unconditionally,
            # so indexing such a split would raise — confirm intended usage.
            self.seg_subs = None
            self.seg_files = None
        self.num_labels = num_labels
    def __len__(self):
        # One item per SUBJECT, not per patch — len() is the subject count
        # (2 in the example), even though each item holds 192 patches.
        return len(self.img_subs)
    def __getitem__(self, index):
        num_labels = self.num_labels
        sub_name = self.img_subs[index]
        img_file = self.img_files[index]
        img_3d = nib.load(img_file)
        # NOTE(review): nibabel's get_data() is deprecated in favor of
        # get_fdata() — consider updating (get_fdata always returns float).
        img = img_3d.get_data()
        # Min-max normalize to [0, 255].
        # NOTE(review): divides by zero if the volume is constant
        # (img.max() == img.min()) — confirm inputs rule that out.
        img = (img - img.min())/(img.max()-img.min())
        img = img*255.0
        seg_file = self.seg_files[index]
        seg_3d = nib.load(seg_file)
        seg = seg_3d.get_data()
        # Cut the full volumes into patch stacks; per the author,
        # 180x256x256 -> imgp (192, 1, 16, 64, 64), segp (192, num_labels, 8, 32, 32).
        imgp, segp = generate_patch_32_3(img, seg)
        # Build a one-hot encoding of the label patches, channel by channel.
        # `labels` is a module-level list of label IDs, e.g. [0, 1, 2, 10, 56, ...].
        for i in range(1,num_labels):
            for j in range(len(imgp)):
                # NOTE(review): this compares against segp — the one-hot array
                # currently being built — rather than the raw label patches, and
                # seg_one is then indexed with only three of segp's five dims.
                # Verify this does what is intended (it looks like it should
                # compare the raw seg patch for patch j instead).
                seg_one = segp == labels[i]
                segp[j, i, :, :, :] = seg_one[0:segp.shape[0], 0:segp.shape[1], 0:segp.shape[2]]
                # Remove label i's voxels from channel 0 (the background channel).
                segp[j, 0, :, :, :] = segp[j, 0, :, :, :] - segp[j, i, :, :, :]
        # print("Here")
        imgp = imgp.astype('float32')
        segp = segp.astype('float32')
        return imgp, segp, sub_name
The `generate_patch_32_3` function simply generates 3D patches: it turns a 180×256×256 image and its segmentation volume into arrays of shape 192×16×64×64 and 192×8×32×32, respectively.
The problem is: after getting a paired source and target from `SegSet`, I don't understand how to load them into the `DataLoader` while also continuing to load the next image from `SegSet`.
# Sanity check: build the dataset and pull one item directly (no DataLoader).
train_set = SegSet(train_dict, num_labels=9)
print(len(train_set))  # number of subjects, not number of patches
# Iterating the Dataset itself yields one subject's full patch stack.
x, y, z = next(iter(train_set))
print(x.shape, '\n', y.shape, '\n', z)
The output is:
2
(192, 1, 16, 64, 64)
(192, 9, 8, 32, 32)
001_MR2std
So the output of the `Dataset` class is multiple (192) sources and targets per item. When I pass it to the `DataLoader`:
# Wrap the dataset in a DataLoader; batch_size counts SUBJECTS (dataset items),
# so each batch stacks whole 192-patch items along a new leading dimension.
train_loader = data.DataLoader(train_set, batch_size=16, shuffle=False, num_workers=1)
print(len(train_loader))  # ceil(num_subjects / batch_size) -> 1 here, since only 2 subjects
x, y, z = next(iter(train_loader))
print(x.shape, '\n', y.shape, '\n', z)  # shapes gain a leading batch dim of 2
The output is:
1
torch.Size([2, 192, 1, 16, 64, 64])
torch.Size([2, 192, 9, 8, 32, 32])
('001_MR2std', '002_MR2std')
It doesn’t change even if I change the batch_size
to 8 or 80.
So if I feed x and y to the network, is it taking all 192 patches at once? What is the significance of batches here?