In a typical training loop in PyTorch, we have an enumerate(dataloader) inside an epoch loop.
How can I change this to run for a number of iterations rather than a number of epochs?
Something like this:
stop = False
tot_iter = 0
max_iterations = 100
while True:  # endless epoch loop; we exit on iteration count instead
    for batch_index, data in enumerate(data_loader, 0):
        ...  # training, backprop, etc.
        tot_iter += 1
        if tot_iter >= max_iterations:
            stop = True
            break
    if stop:
        break
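If you'd rather not juggle the stop flag and the double break, a small generator can flatten the two loops into one. A minimal sketch (infinite_batches is just a helper name for illustration, not a PyTorch API):
from itertools import islice

def infinite_batches(loader):
    # Restart the dataloader each time it is exhausted
    while True:
        yield from loader

max_iterations = 100
for tot_iter, data in enumerate(islice(infinite_batches(data_loader), max_iterations)):
    ...  # training, backprop, etc.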
Yeah that works for sure. Thanks.
See this link: mmengine/loops.py at main · open-mmlab/mmengine · GitHub
The class `_InfiniteDataloaderIterator` may help you:
import logging
import time
from typing import Sequence

from torch.utils.data import DataLoader
from mmengine.logging import print_log


class _InfiniteDataloaderIterator:
    """An infinite dataloader iterator wrapper for IterBasedTrainLoop.

    It resets the dataloader to continue iterating when the iterator has
    iterated over all the data. However, this approach is not efficient, as
    the workers need to be restarted every time the dataloader is reset. It is
    recommended to use `mmengine.dataset.InfiniteSampler` to enable the
    dataloader to iterate infinitely.
    """

    def __init__(self, dataloader: DataLoader) -> None:
        self._dataloader = dataloader
        self._iterator = iter(self._dataloader)
        self._epoch = 0

    def __iter__(self):
        return self

    def __next__(self) -> Sequence[dict]:
        try:
            data = next(self._iterator)
        except StopIteration:
            print_log(
                'Reached the end of the dataloader; it will be '
                'restarted and continue to iterate. It is '
                'recommended to use '
                '`mmengine.dataset.InfiniteSampler` to enable the '
                'dataloader to iterate infinitely.',
                logger='current',
                level=logging.WARNING)
            self._epoch += 1
            if hasattr(self._dataloader, 'sampler') and hasattr(
                    self._dataloader.sampler, 'set_epoch'):
                # In case the `_SingleProcessDataLoaderIter` has no sampler,
                # or the dataloader uses `SequentialSampler` in PyTorch.
                self._dataloader.sampler.set_epoch(self._epoch)
            elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
                    self._dataloader.batch_sampler.sampler, 'set_epoch'):
                # In case the `_SingleProcessDataLoaderIter` has no batch
                # sampler. The batch sampler in PyTorch wraps the sampler as
                # one of its attributes.
                self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self._iterator = iter(self._dataloader)
            data = next(self._iterator)
        return data
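For reference, a minimal sketch of how the wrapper could drive an iteration-based loop (train_step is a placeholder for your own training code, not an mmengine API):
data_iterator = _InfiniteDataloaderIterator(data_loader)
for tot_iter in range(max_iterations):
    data = next(data_iterator)  # never raises StopIteration; restarts internally
    train_step(data)            # placeholder: forward pass, loss, backprop, step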