i think the source for my slow net is the dataloader.
currently im just trying to overfit a network on one sample so i was suspicious it takes so long.
here is the dataset:
( just to be more clear, the dataset loads an existing torch tensor, then silences a small square inside and returns it as input and the original tensor is the target, image inpainting task).
class HCP_sanity(Dataset):
"""
extra dataset class to check stuff on one sample
"""
def __init__(self, args, dataset_path, norm_method, time_unit=5,
subjects=155, load=False):
self.root = str(dataset_path)
self.norm_method = norm_method
self.full_vol_dim = (64, 64, 32, 168) # slice, width, height, time
self.time_unit = time_unit
self.data_dir = os.path.join(self.root, 'hcp_clean_' + self.norm_method + '_normalize',
'TR_{}'.format(self.time_unit))
file_name = os.listdir(self.data_dir)[500:501] #only one file because overfitting
sample_path = os.path.join(self.data_dir, file_name[0])
self.data = torch.load(sample_path) #load data to memory (shape:(64,64,32,5))
self.slice = False
if args.model == 'UNET2D_go':
self.slice = True
def __len__(self):
return 1
def __getitem__(self, index):
if self.slice:
depth = 15
target = self.data[:,:,depth,0]
x = target.clone()
x[36:44,26:34] = 0
else:
target = self.data[:,:,:,0]
x = target.clone()
x[36:44,26:34,16:20] = 0
return x.unsqueeze(0), target.unsqueeze(0)
dataset = HCP_sanity()
generator = DataLoader(dataset,batch_size = 64,shuffle=False,num_workers=1)
a,b = iter(generator).next() # **this takes about 2 seconds!!!**
also, i guess since i use only one sample and the get_item function loads only one sample, the batch_size is ignored and the shape of the variable a is (1,1,64,64)