Dataloading is too slow

I have created a dataloader which loads 12MB file in init() method and loads 2 images in getitem(). For this dataloader, loading time is slow(4s - 10s)

I measure data time as below:

data_loader = DataLoader(data, batch_size=16, num_workers=8, shuffle=True, drop_last=True)
end = time.time()
for data_iter_step, batch in enumerate(data_loader):
    print(f"datatime: {(time.time()-end):.3f}s")

Dataloader class:

class Ho3dHandObjRelPopse(Dataset):
    """
    ho3d dataset for rel-pose estimation
    """
    def __init__(self, interf_cls, sampl_strat, ang_thmin, ang_thmax, split, num_sampl_frms=2):
        self.sampl_strat = sampl_strat 
        self.ang_thmin = ang_thmin 
        self.ang_thmax = ang_thmax
        self.datasetcls = interf_cls
        self.num_sampl_frms = num_sampl_frms
        self.split = split
        self.datadir = self.datasetcls.datadir

        self.all_seq_ids = os.listdir(f"{self.datadir}/{self.split}") 
        print(f'Split: {self.split} | No. of Seqs: {len(self.all_seq_ids)}')

        # get all poses into one variable
        start0 = time.time()
        print(f'Loading {self.split} seqs poses...')
        self.all_seq_obj_poses = load_pkl(osp.join(HO3D_CACHE_DIR, f"{self.split}_all_seq_poses.pkl"))  # ~0.4-0.7s (12MB file)
        print(f"self.all_seq_obj_poses Time: {(time.time()-start0):.3f}s")

    def __getitem__(self, index):
        start1 = time.time()
        seqid = self.all_seq_ids[index]
        seq_obj_poses = torch.from_numpy(np.array(self.all_seq_obj_poses[seqid]))
        print(f"seq_obj_poses Time: {(time.time()-start1):.3f}s")

        start2 = time.time()
        rdiff_mat = torch.rad2deg(rotmat_geodesic_distance(seq_obj_poses[None, :, :3, :3], seq_obj_poses[:, None, :3, :3]))
        # rdiff_mat = torch.rad2deg(rotmat_geodesic_distance(seq_obj_poses[:, :3, :3], seq_obj_poses[:, :3, :3].unsqueeze(1))) # must give same res

        row_inds, col_inds = torch.where((rdiff_mat <= self.ang_thmax) & (rdiff_mat > self.ang_thmin))
        # row_inds, col_inds = torch.where((rdiff_mat <= self.ang_thmax) & (rdiff_mat > self.ang_thmin) & (torch.logical_not(rdiff_mat.isnan())))

        rand_ind = random.choice(range(len(row_inds)))
        pair_inds = (row_inds[rand_ind], col_inds[rand_ind])
        # print("pair_inds", pair_inds)
        print(f"Pairs Sample Time: {(time.time()-start2):.3f}s") # ~0.04s

        p1 = seq_obj_poses[pair_inds[0]]
        p2 = seq_obj_poses[pair_inds[1]]
        
        imgpth1 = osp.join(self.datadir, self.split, seqid, f"rgb/{pair_inds[0]:04}.jpg")
        imgpth2 = osp.join(self.datadir, self.split, seqid, f"rgb/{pair_inds[1]:04}.jpg")
        img1 = torch.from_numpy(self.get_img(imgpth1))
        img2 = torch.from_numpy(self.get_img(imgpth2))
        imgs = torch.stack([img1, img2])
        poses = torch.stack([p1, p2])

        print(f"__getitem__() Time: {(time.time()-start1):.3f}s") # ~0.07-0.09s
        
        ret_dict = DottedDict(
            {
                'imgs': imgs,
                'trans': poses[:, :3, 3],
                'rots' : poses[:, :3, :3],
                'pose': poses
            }
        )
        return ret_dict 
    
    def __len__(self, ):
        "len of all valid seqs"
        return len(self.all_seq_ids)
    
    def get_all_meta_pths(self, split):
        if split == 'train':
            pths = glob.glob(osp.join(self.datadir, 'train/*/meta/*.pkl')) 
        elif split == 'val':
            pths = glob.glob(osp.join(self.datadir, 'evaluation/*/meta/*.pkl'))
        elif split == 'all':
            pths = glob.glob(osp.join(self.datadir, 'train/*/meta/*.pkl')) + glob.glob(osp.join(self.datadir, 'evaluation/*/meta/*.pkl'))
        return pths
    
    def get_img(
        self, pth, n_retries: int = 5, retry_delay: float = 1.):
        for _ in range(n_retries):
            try:
                return imread(pth)
            except OSError as e:
                sleep(retry_delay)
                err = e
        raise err