Cannot access all data when enumerating the DataLoader

I defined a custom Dataset and a custom DataLoader, and I want to access all the batches with for i, batch in enumerate(loader). But this loop gives me a different number of batches in every epoch, and all of them are far smaller than the actual number of batches (which should equal number_of_samples / batch_size).

Here is how I define my dataset and dataloader:



import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class UsptoDataset(Dataset):
    def __init__(self, csv_file):
        df = pd.read_csv(csv_file)
        self.rea_trees = df['reactants_trees'].to_numpy()
        self.syn_trees = df['synthons_trees'].to_numpy()
        self.syn_smiles = df['synthons'].to_numpy()
        self.product_smiles = df['product'].to_numpy()

    def __len__(self):
        return len(self.rea_trees)

    def __getitem__(self, item):
        rea_tree = self.rea_trees[item]
        syn_tree = self.syn_trees[item]
        syn_smile = self.syn_smiles[item]
        pro_smile = self.product_smiles[item]
        # data processing omitted here; it produces the variables used in the return statement below.
        return {'input_words': input_words,
                'input_chars': input_chars,
                'syn_tree_indices': syn_tree_indices,
                'syn_rule_nl_left': syn_rule_nl_left,
                'syn_rule_nl_right': syn_rule_nl_right,
                'rea_tree_indices': rea_tree_indices,
                'rea_rule_nl_left': rea_rule_nl_left,
                'rea_rule_nl_right': rea_rule_nl_right,
                'class_mask': class_mask,
                'query_paths': query_paths,
                'labels': labels,
                'parent_matrix': parent_matrix,
                'syn_parent_matrix': syn_parent_matrix,
                'path_lens': path_lens,
                'syn_path_lens': syn_path_lens}

    @staticmethod
    def collate_fn(batch):
        input_words = torch.tensor(np.stack([_['input_words'] for _ in batch], axis=0), dtype=torch.long)
        input_chars = torch.tensor(np.stack([_['input_chars'] for _ in batch], axis=0), dtype=torch.long)
        syn_tree_indices = torch.tensor(np.stack([_['syn_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_left = torch.tensor(np.stack([_['syn_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        syn_rule_nl_right = torch.tensor(np.stack([_['syn_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        rea_tree_indices = torch.tensor(np.stack([_['rea_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_left = torch.tensor(np.stack([_['rea_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
        rea_rule_nl_right = torch.tensor(np.stack([_['rea_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
        class_mask = torch.tensor(np.stack([_['class_mask'] for _ in batch], axis=0), dtype=torch.float32)
        query_paths = torch.tensor(np.stack([_['query_paths'] for _ in batch], axis=0), dtype=torch.long)
        labels = torch.tensor(np.stack([_['labels'] for _ in batch], axis=0), dtype=torch.long)
        parent_matrix = torch.tensor(np.stack([_['parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        syn_parent_matrix = torch.tensor(np.stack([_['syn_parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
        path_lens = torch.tensor(np.stack([_['path_lens'] for _ in batch], axis=0), dtype=torch.long)
        syn_path_lens = torch.tensor(np.stack([_['syn_path_lens'] for _ in batch], axis=0), dtype=torch.long)

        return_dict = {'input_words': input_words,
                       'input_chars': input_chars,
                       'syn_tree_indices': syn_tree_indices,
                       'syn_rule_nl_left': syn_rule_nl_left,
                       'syn_rule_nl_right': syn_rule_nl_right,
                       'rea_tree_indices': rea_tree_indices,
                       'rea_rule_nl_left': rea_rule_nl_left,
                       'rea_rule_nl_right': rea_rule_nl_right,
                       'class_mask': class_mask,
                       'query_paths': query_paths,
                       'labels': labels,
                       'parent_matrix': parent_matrix,
                       'syn_parent_matrix': syn_parent_matrix,
                       'path_lens': path_lens,
                       'syn_path_lens': syn_path_lens}
        return return_dict
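
(Side note: the collate_fn above just stacks every field across the batch and converts it to a tensor, so it should be equivalent to this loop over the keys, using the same keys and dtypes as above:)

def collate_fn(batch):
    # fields that need a float tensor; every other field becomes a long tensor
    float_keys = {'class_mask', 'parent_matrix', 'syn_parent_matrix'}
    return {key: torch.tensor(np.stack([sample[key] for sample in batch], axis=0),
                              dtype=torch.float32 if key in float_keys else torch.long)
            for key in batch[0]}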


train_dataset=UsptoDataset("train_trees.csv")

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=UsptoDataset.collate_fn)         

And when I use the dataloader as follows, it gives me a different number of batches every epoch:

epoch_steps = len(train_loader)
for e in range(epochs):
    for j, batch_data in enumerate(train_loader):
        step = e * epoch_steps + j
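
To make sure this is not just a logging artifact, I can also count the batches the loader actually yields and print the count at the end of every epoch (a small instrumentation sketch; the actual training step is omitted):

for e in range(epochs):
    seen = 0
    for j, batch_data in enumerate(train_loader):
        step = e * epoch_steps + j
        seen += 1  # count every batch the loader actually yields
    # compare against len(train_loader), which should be 10002 here
    print(f"epoch {e}: {seen} of {epoch_steps} batches")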

The log shows that the first epoch has only 5 batches, the second epoch has 3 batches, the third epoch has 5 batches again, and so on.

Config:
Namespace(batch_size_per_gpu=4, epochs=400, eval_every_epoch=1, hidden_size=128, keep=10, log_every_step=1, lr=0.001, new_model=False, save_dir='saved_model/', workers=1)
2021-01-06 15:33:17,909 - __main__ - WARNING - Checkpoints not found in dir saved_model/, creating a new model.
2021-01-06 15:33:18,340 - __main__ - INFO - Step: 0, Loss: 5.4213, Rule acc: 0.1388
2021-01-06 15:33:18,686 - __main__ - INFO - Step: 1, Loss: 4.884, Rule acc: 0.542
2021-01-06 15:33:18,941 - __main__ - INFO - Step: 2, Loss: 4.6205, Rule acc: 0.6122
2021-01-06 15:33:19,174 - __main__ - INFO - Step: 3, Loss: 4.4442, Rule acc: 0.61
2021-01-06 15:33:19,424 - __main__ - INFO - Step: 4, Loss: 4.3033, Rule acc: 0.6211
2021-01-06 15:33:20,684 - __main__ - INFO - Dev Loss: 3.5034, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5970844200679234, in epoch 0
2021-01-06 15:33:22,203 - __main__ - INFO - Test Loss: 3.4878, Test Sample Acc: 0.0, Test Rule Acc: 0.6470248053471247
2021-01-06 15:33:22,394 - __main__ - INFO - Found better dev sample accuracy 0.0 in epoch 0
2021-01-06 15:33:22,803 - __main__ - INFO - Step: 10002, Loss: 3.6232, Rule acc: 0.6555
2021-01-06 15:33:23,046 - __main__ - INFO - Step: 10003, Loss: 3.53, Rule acc: 0.6442
2021-01-06 15:33:23,286 - __main__ - INFO - Step: 10004, Loss: 3.4907, Rule acc: 0.6498
2021-01-06 15:33:24,617 - __main__ - INFO - Dev Loss: 3.3081, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5980878387178693, in epoch 1
2021-01-06 15:33:26,215 - __main__ - INFO - Test Loss: 3.2859, Test Sample Acc: 0.0, Test Rule Acc: 0.6466992994149526
2021-01-06 15:33:26,857 - __main__ - INFO - Step: 20004, Loss: 3.3965, Rule acc: 0.6493
2021-01-06 15:33:27,093 - __main__ - INFO - Step: 20005, Loss: 3.3797, Rule acc: 0.6314
2021-01-06 15:33:27,353 - __main__ - INFO - Step: 20006, Loss: 3.3959, Rule acc: 0.5727
2021-01-06 15:33:27,609 - __main__ - INFO - Step: 20007, Loss: 3.3632, Rule acc: 0.6279
2021-01-06 15:33:27,837 - __main__ - INFO - Step: 20008, Loss: 3.3331, Rule acc: 0.6158
2021-01-06 15:33:29,122 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 2
2021-01-06 15:33:30,689 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:32,143 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 3
2021-01-06 15:33:33,765 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
2021-01-06 15:33:34,359 - __main__ - INFO - Step: 40008, Loss: 3.108, Rule acc: 0.6816
2021-01-06 15:33:34,604 - __main__ - INFO - Step: 40009, Loss: 3.0756, Rule acc: 0.6732
2021-01-06 15:33:35,823 - __main__ - INFO - Dev Loss: 3.0419, Dev Sample Acc: 0.0, Dev Rule Acc: 0.613776079245976, in epoch 4

FYI, the values of len(train_loader.dataset), batch_size and len(train_loader) are 40008, 4 and 10002 respectively, which is exactly what I expected. So it is very confusing that enumerating the loader only gives me a handful of batches, such as 3 or 5.

Are you using any break statements or other conditions inside the DataLoader loop?
If not, could you post an executable code snippet with random values for the dataset inputs, so that we could try to reproduce it?
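
For example, something along these lines (the keys and shapes are just placeholders, not your real data) would already let us run it:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class RandomUsptoDataset(Dataset):
    """Stand-in for UsptoDataset that returns random arrays; shapes are placeholders."""
    def __init__(self, num_samples=40008):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, item):
        return {'input_words': np.random.randint(0, 100, size=(50,)),
                'labels': np.random.randint(0, 10, size=(20,))}

    @staticmethod
    def collate_fn(batch):
        return {key: torch.tensor(np.stack([sample[key] for sample in batch], axis=0), dtype=torch.long)
                for key in batch[0]}


loader = DataLoader(RandomUsptoDataset(), batch_size=4, shuffle=True, num_workers=1,
                    collate_fn=RandomUsptoDataset.collate_fn)

for e in range(3):
    count = sum(1 for _ in loader)  # iterate one full epoch and count the batches
    print(f"epoch {e}: {count} batches (len(loader) = {len(loader)})")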