Create a custom dataloader


I am working with a custom dataloader and wanted to know whether it is necessary to create a class with __init__ and __getitem__, because I tried writing my own method that returns batches of data. Since I use return instead of yield, I only get one batch of data instead of all of it.

However, if I use yield, I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-35-30905f729d78> in <module>
----> 1 train_loss = training(model, train_dl, param.epochs)

<ipython-input-34-100bda3ef538> in training(model, train_loader, Epochs)
      3     for epoch in range(Epochs):
      4         running_loss = 0.0
----> 5         for data in train_loader:
      6             img = data
      7             img =, dtype=torch.float)

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/ in __iter__(self)
    277             return _SingleProcessDataLoaderIter(self)
    278         else:
--> 279             return _MultiProcessingDataLoaderIter(self)
    281     @property

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/ in __init__(self, loader)
    744         # prime the prefetch loop
    745         for _ in range(2 * self._num_workers):
--> 746             self._try_put_index()
    748     def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/ in _try_put_index(self)
    859         assert self._tasks_outstanding < 2 * self._num_workers
    860         try:
--> 861             index = self._next_index()
    862         except StopIteration:
    863             return

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/ in _next_index(self)
    338     def _next_index(self):
--> 339         return next(self._sampler_iter)  # may raise StopIteration
    341     def _next_data(self):

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/ in __iter__(self)
    198     def __iter__(self):
    199         batch = []
--> 200         for idx in self.sampler:
    201             batch.append(idx)
    202             if len(batch) == self.batch_size:

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/ in __iter__(self)
     61     def __iter__(self):
---> 62         return iter(range(len(self.data_source)))
     64     def __len__(self):

TypeError: object of type 'generator' has no len()

So is it really necessary to use the prescribed Dataset approach, or is there some way I can use yield?

Could you describe your use case and why you need to create a custom DataLoader?
Usually you would create a custom Dataset (as described here) and, if necessary, write a custom collate_fn for the DataLoader, but wouldn’t need to manipulate the DataLoader directly.



Thank you for the reply, and my sincere apologies for the late response. That is exactly why I want to create a custom dataloader: I am not working with a generic dataset, and converting it into a generic one would take a lot of time.

I think collate_fn is what I was looking for; its absence is probably why I am getting this error.