Hi, I tried using a customized collate_fn in DataLoader, like this:
```python
def _collate_fn(batch):  # batch_size == 2
    # some code processing batch
    # A = torch.tensor()
    # B = torch.tensor()
    # C = torch.tensor()
    # D = torch.tensor()
    # E = (str, str)
    # F = (sample_idx, sample_idx)
    return A, B, C, D, E, F

data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=_collate_fn)

for idx, batch in enumerate(data_loader):
    # some code
```
Here, collate_fn returns a tuple with 6 elements, and the train dataset has 100 samples. But when I iterate over the dataloader, I cannot go through the whole dataset in one epoch: it stops at iter_num == (100 / 6) + 1, while it should stop at iter_num == 100 / 2, since batch_size == 2. I can't figure out why. Could someone help me?
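For comparison, here is a minimal self-contained version of what I would expect to happen (the dataset, tensors, and names here are made up for illustration, not my real code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Dummy 100-sample dataset standing in for my real one."""
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        # (feature tensor, a string, the sample index)
        return torch.randn(3), f"sample_{idx}", idx

def _collate_fn(batch):  # batch is a list of 2 samples
    feats, names, idxs = zip(*batch)
    A = torch.stack(feats)   # shape (2, 3)
    B = A + 1                # placeholders for my real processed tensors
    C = A * 2
    D = A - 1
    E = names                # (str, str)
    F = idxs                 # (sample_idx, sample_idx)
    return A, B, C, D, E, F

loader = DataLoader(ToyDataset(), batch_size=2, shuffle=True,
                    collate_fn=_collate_fn)

num_iters = sum(1 for _ in loader)
print(num_iters)  # I would expect 100 / 2 = 50 iterations
```

In this toy version, the number of iterations per epoch is 50, regardless of how many elements the collate_fn returns per batch.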