rwightman
(Ross Wightman)
July 4, 2020, 1:02am
4
@vibe2 A researcher submitted a solution to this issue to my codebase a while back; apparently it has also been discussed in a PyTorch issue/PR for some time.
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses one underlying iterator for every epoch.

    The index sampler is wrapped in ``_RepeatSampler`` and the base-class
    iterator is created exactly once, in ``__init__``. Each pass over the
    loader then draws ``len(self)`` items from that single iterator instead
    of building a new one (presumably to avoid per-epoch iterator start-up
    cost -- see the PR linked from the surrounding post).

    Accepts the same positional and keyword arguments as
    ``torch.utils.data.DataLoader``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader guards its sampler attributes against reassignment once
        # it is initialised, so lower the flag while swapping the sampler.
        # The mangled name is spelled out because ``self.__initialized``
        # written here would mangle to this subclass's name instead.
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            # Automatic batching is disabled (batch_size=None): the loader
            # draws indices from ``sampler``, so that is what must repeat.
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single long-lived iterator shared by every call to __iter__.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the wrapped (un-repeated) sampler."""
        if self.batch_sampler is None:
            return len(self.sampler.sampler)
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yield ``len(self)`` items from the shared iterator (one epoch)."""
        for _ in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
This file has been truncated. show original
Here is the PR that was submitted, with some information about the solution from the author: https://github.com/rwightman/pytorch-image-models/pull/140
1 Like