Hi, I tried using a customized collate_fn in DataLoader, like this:
def _collate_fn(batch):
    # batch_size == 2
    # some code processing the batch
    # A = torch.tensor()
    # B = torch.tensor()
    # C = torch.tensor()
    # D = torch.tensor()
    # E = (str, str)
    # F = (sample_idx, sample_idx)
    return A, B, C, D, E, F

data_loader = DataLoader(dataset,
                         batch_size=2, shuffle=True, collate_fn=_collate_fn)

for idx, batch in enumerate(data_loader):
    # some code
Here, collate_fn returns a tuple with 6 elements, and the training dataset has 100 samples. When I iterate over the dataloader, I cannot go through the whole dataset in one epoch: iteration stops at iter_num == (100 / 6) + 1. With batch_size == 2, it should have stopped at iter_num == 100 / 2. I cannot figure out why. Could someone help me?
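To make the setup concrete, here is a minimal runnable sketch of what I expect to happen (ToyDataset, the tensor arithmetic, and the names A–F are placeholders standing in for my real code): with 100 samples and batch_size == 2, the loop should yield 50 batches, and the number of elements in the tuple returned by collate_fn should not affect the batch count.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical stand-in dataset: 100 samples, each a (tensor, str, idx) triple."""
    def __init__(self, n=100):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.randn(3), f"name_{idx}", idx

def _collate_fn(batch):
    # batch is a list of batch_size samples; build six outputs from it
    A = torch.stack([x for x, _, _ in batch])   # (batch_size, 3) tensor
    B = A + 1.0                                  # placeholder derived tensors
    C = A * 2.0
    D = A - 1.0
    E = tuple(s for _, s, _ in batch)            # (str, str)
    F = tuple(i for _, _, i in batch)            # (sample_idx, sample_idx)
    return A, B, C, D, E, F

loader = DataLoader(ToyDataset(), batch_size=2, shuffle=True,
                    collate_fn=_collate_fn)

# One epoch should visit every sample once: 100 samples / batch_size 2
n_batches = sum(1 for _ in loader)
print(n_batches)  # expected: 50, regardless of the 6-tuple return
```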