Hi,
I am working with a custom dataloader and wanted to know whether it is necessary to create a class with __init__ and __getitem__. I ask because I tried to write my own method that returns batches of data, but since it uses return instead of yield, I only get one batch of data instead of all the data.
However, if I use yield, I get the following error:
TypeError Traceback (most recent call last)
<ipython-input-35-30905f729d78> in <module>
----> 1 train_loss = training(model, train_dl, param.epochs)
<ipython-input-34-100bda3ef538> in training(model, train_loader, Epochs)
3 for epoch in range(Epochs):
4 running_loss = 0.0
----> 5 for data in train_loader:
6 img = data
7 img = img.to(device, dtype=torch.float)
~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __iter__(self)
277 return _SingleProcessDataLoaderIter(self)
278 else:
--> 279 return _MultiProcessingDataLoaderIter(self)
280
281 @property
~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __init__(self, loader)
744 # prime the prefetch loop
745 for _ in range(2 * self._num_workers):
--> 746 self._try_put_index()
747
748 def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _try_put_index(self)
859 assert self._tasks_outstanding < 2 * self._num_workers
860 try:
--> 861 index = self._next_index()
862 except StopIteration:
863 return
~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_index(self)
337
338 def _next_index(self):
--> 339 return next(self._sampler_iter) # may raise StopIteration
340
341 def _next_data(self):
~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/sampler.py in __iter__(self)
198 def __iter__(self):
199 batch = []
--> 200 for idx in self.sampler:
201 batch.append(idx)
202 if len(batch) == self.batch_size:
~/anaconda3/envs/work/lib/python3.8/site-packages/torch/utils/data/sampler.py in __iter__(self)
60
61 def __iter__(self):
---> 62 return iter(range(len(self.data_source)))
63
64 def __len__(self):
TypeError: object of type 'generator' has no len()
So, is it really necessary to use the prescribed class-based approach (a class with __init__ and __getitem__), or is there some way I can use yield instead?