I have created a dataloader that loads a 12 MB pickle file in its __init__() method and loads 2 images per sample in __getitem__(). For this dataloader, the loading time per batch is slow (4s - 10s).
I measure the data loading time as below:

data_loader = DataLoader(data, batch_size=16, num_workers=8, shuffle=True, drop_last=True)
end = time.time()
for data_iter_step, batch in enumerate(data_loader):
    print(f"datatime: {(time.time()-end):.3f}s")
    end = time.time()  # reset so each iteration reports only its own wait
Dataset class:

import glob
import os
import os.path as osp
import random
import time
from time import sleep

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# load_pkl, rotmat_geodesic_distance, DottedDict, HO3D_CACHE_DIR and imread
# are project-specific helpers from my codebase.

class Ho3dHandObjRelPopse(Dataset):
"""
ho3d dataset for rel-pose estimation
"""
def __init__(self, interf_cls, sampl_strat, ang_thmin, ang_thmax, split, num_sampl_frms=2):
self.sampl_strat = sampl_strat
self.ang_thmin = ang_thmin
self.ang_thmax = ang_thmax
self.datasetcls = interf_cls
self.num_sampl_frms = num_sampl_frms
self.split = split
self.datadir = self.datasetcls.datadir
self.all_seq_ids = os.listdir(f"{self.datadir}/{self.split}")
print(f'Split: {self.split} | No. of Seqs: {len(self.all_seq_ids)}')
# get all poses into one variable
start0 = time.time()
print(f'Loading {self.split} seqs poses...')
self.all_seq_obj_poses = load_pkl(osp.join(HO3D_CACHE_DIR, f"{self.split}_all_seq_poses.pkl")) # ~0.4-0.7s (12MB file)
print(f"self.all_seq_obj_poses Time: {(time.time()-start0):.3f}s")
    def __getitem__(self, index):
        start1 = time.time()
        seqid = self.all_seq_ids[index]
        seq_obj_poses = torch.from_numpy(np.array(self.all_seq_obj_poses[seqid]))
        print(f"seq_obj_poses Time: {(time.time()-start1):.3f}s")

        start2 = time.time()
        # N x N matrix of pairwise geodesic angles (degrees), recomputed on every call
        rdiff_mat = torch.rad2deg(rotmat_geodesic_distance(seq_obj_poses[None, :, :3, :3], seq_obj_poses[:, None, :3, :3]))
        # rdiff_mat = torch.rad2deg(rotmat_geodesic_distance(seq_obj_poses[:, :3, :3], seq_obj_poses[:, :3, :3].unsqueeze(1)))  # must give the same result
        row_inds, col_inds = torch.where((rdiff_mat <= self.ang_thmax) & (rdiff_mat > self.ang_thmin))
        # row_inds, col_inds = torch.where((rdiff_mat <= self.ang_thmax) & (rdiff_mat > self.ang_thmin) & (torch.logical_not(rdiff_mat.isnan())))
        rand_ind = random.randrange(len(row_inds))
        pair_inds = (row_inds[rand_ind], col_inds[rand_ind])
        # print("pair_inds", pair_inds)
        print(f"Pairs Sample Time: {(time.time()-start2):.3f}s")  # ~0.04s

        p1 = seq_obj_poses[pair_inds[0]]
        p2 = seq_obj_poses[pair_inds[1]]
        imgpth1 = osp.join(self.datadir, self.split, seqid, f"rgb/{pair_inds[0]:04}.jpg")
        imgpth2 = osp.join(self.datadir, self.split, seqid, f"rgb/{pair_inds[1]:04}.jpg")
        img1 = torch.from_numpy(self.get_img(imgpth1))
        img2 = torch.from_numpy(self.get_img(imgpth2))
        imgs = torch.stack([img1, img2])
        poses = torch.stack([p1, p2])
        print(f"__getitem__() Time: {(time.time()-start1):.3f}s")  # ~0.07-0.09s

        ret_dict = DottedDict(
            {
                'imgs': imgs,
                'trans': poses[:, :3, 3],
                'rots': poses[:, :3, :3],
                'pose': poses,
            }
        )
        return ret_dict
    def __len__(self):
        """Number of valid seqs."""
        return len(self.all_seq_ids)
    def get_all_meta_pths(self, split):
        if split == 'train':
            pths = glob.glob(osp.join(self.datadir, 'train/*/meta/*.pkl'))
        elif split == 'val':
            pths = glob.glob(osp.join(self.datadir, 'evaluation/*/meta/*.pkl'))
        elif split == 'all':
            pths = glob.glob(osp.join(self.datadir, 'train/*/meta/*.pkl')) + glob.glob(osp.join(self.datadir, 'evaluation/*/meta/*.pkl'))
        else:
            raise ValueError(f"unknown split: {split}")
        return pths
    def get_img(self, pth, n_retries: int = 5, retry_delay: float = 1.0):
        """Read an image, retrying on transient OSErrors."""
        for _ in range(n_retries):
            try:
                return imread(pth)
            except OSError as e:
                err = e
                sleep(retry_delay)
        raise err
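One workaround I am considering: every __getitem__ call rebuilds the full N x N pairwise angle matrix for a sequence just to sample one pair. Since all poses are already in memory after __init__, the valid (row, col) pairs per sequence could be precomputed once and only indexed at fetch time. A minimal sketch of that idea (precompute_valid_pairs and self.valid_pair_inds are hypothetical names, not part of the class above):

    def precompute_valid_pairs(self):
        # hypothetical: called once at the end of __init__;
        # trades startup time and memory for a cheap __getitem__
        self.valid_pair_inds = {}
        for seqid in self.all_seq_ids:
            poses = torch.from_numpy(np.array(self.all_seq_obj_poses[seqid]))
            rdiff = torch.rad2deg(rotmat_geodesic_distance(
                poses[None, :, :3, :3], poses[:, None, :3, :3]))
            rows, cols = torch.where((rdiff <= self.ang_thmax) & (rdiff > self.ang_thmin))
            self.valid_pair_inds[seqid] = torch.stack([rows, cols], dim=1)

    # in __getitem__, the sampling then reduces to:
    #     pairs = self.valid_pair_inds[seqid]
    #     pair_inds = pairs[random.randrange(len(pairs))]

That said, the prints suggest pair sampling takes only ~0.04s and __getitem__ itself ~0.07-0.09s, so if each batch really takes 4-10s, the time must be going into worker startup or into get_img / disk I/O; timing those two separately should narrow it down.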