Why does enumerating a dataloader show a different object than accessing the dataset by index?

I’m trying to adapt an existing model to my own use case and dataset. To do that, I had to change the information that the dataset’s __getitem__ method returns (I added some keys to the returned dictionary). For some reason, the new keys don’t show up when I loop over the DataLoader that loads the dataset. While debugging I noticed that accessing the dataset by index returns the keys I added, but iterating over the DataLoader with “enumerate” only shows the old keys.

Here is a code snippet from my trainer.fit() method, together with its outputs:

from icecream import ic  # debug-print helper

for batch_idx, batch in enumerate(train_loader):
    ic(type(train_loader))
    ic(type(train_loader.dataset))

    # accessing the underlying dataset directly by index
    ic(train_loader.dataset[0].keys())
    ic(type(train_loader.dataset[0]))

    # the batch produced by the DataLoader
    ic(batch.keys())
    ic(type(batch))

Outputs:

ic| type(train_loader): <class 'torch.utils.data.dataloader.DataLoader'>
ic| type(train_loader.dataset): <class 'torch.utils.data.dataset.ConcatDataset'>
ic| train_loader.dataset[0].keys(): dict_keys(['src_xyz', 'src_grey', 'tgt_xyz', 'tgt_grey', 'src_overlap', 'tgt_overlap', 'correspondences', 'pose', 'idx', 'overlap_p'])
ic| batch.keys(): dict_keys(['src_xyz', 'tgt_xyz', 'src_overlap', 'tgt_overlap', 'correspondences', 'idx', 'pose', 'overlap_p'])
ic| type(batch): <class 'dict'>

As you can see, the keys in train_loader.dataset[0] are different from those in batch. Two questions arise from this:

  1. How is train_loader.dataset[0] not the same as the batch variable?
  2. Why do we enumerate over the dataloader, not over dataloader.dataset?

I hope the context is sufficient to understand the question!

  1. loader.dataset accesses the dataset that was passed to the DataLoader in the main process. If you are using multiprocessing via num_workers>0, each worker creates its own copy of this dataset when it is started. You can get information about the worker process via torch.utils.data.get_worker_info() inside Dataset.__getitem__ (see the sketch after this list).

  2. You don’t need to use a DataLoader, but it allows you to:

  • shuffle the data
  • use pinned memory
  • use multiprocessing via num_workers>0
  • batch the data
  • use custom samplers, etc.
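
As a quick illustration of the worker-copy point in item 1, here is a minimal sketch (the dataset and key names are made up for illustration, not taken from your code) showing get_worker_info() inside __getitem__:

from torch.utils.data import Dataset, DataLoader, get_worker_info

class ToyDataset(Dataset):
    """Toy dataset that reports which process served each sample."""
    def __init__(self):
        self.data = list(range(8))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        info = get_worker_info()  # None in the main process, a WorkerInfo object in a worker
        worker_id = info.id if info is not None else -1
        return {"value": self.data[idx], "worker": worker_id}

loader = DataLoader(ToyDataset(), batch_size=4, num_workers=2)
for batch in loader:
    # each worker holds its own copy of the dataset, created when the worker starts
    print(batch["worker"], batch["value"])

Note that indexing loader.dataset[0] in the main process bypasses both the workers and the collate_fn, which is why it can look different from the batches you get while iterating.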

Thank you! It turned out there was a problem with the collate_fn.
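
For anyone who finds this later: the DataLoader never returns dataset[idx] directly; it collects a list of samples and passes them to collate_fn, which builds the batch dict you see when iterating. A hypothetical sketch (the function names are made up) of how a collate_fn with a hard-coded key list can silently drop newly added keys, and one way to avoid it:

OLD_KEYS = ['src_xyz', 'tgt_xyz', 'src_overlap', 'tgt_overlap',
            'correspondences', 'idx', 'pose', 'overlap_p']

def collate_fn_old(samples):
    # only the hard-coded keys make it into the batch; keys added later in
    # __getitem__ (e.g. 'src_grey', 'tgt_grey') are silently dropped
    return {key: [sample[key] for sample in samples] for key in OLD_KEYS}

def collate_fn_fixed(samples):
    # derive the keys from the samples themselves so new keys are carried through
    return {key: [sample[key] for sample in samples] for key in samples[0].keys()}

# passed to the loader via DataLoader(dataset, batch_size=..., collate_fn=collate_fn_fixed)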