Is there any way to create a new subset without replacement every epoch (i.e., once a sample is used, it will not show up in the following epochs until the entire dataset has been iterated)?
Inspired by the set_epoch() method of DistributedSampler, I added a set_epoch() method to the current RandomSampler so that each epoch uses a different subset of the dataset. Can you review the code to see whether it is correct?
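For reference, this is the kind of set_epoch() loop that DistributedSampler already supports (a minimal sketch; the dataset, batch size, and epoch count are just placeholders, and num_replicas/rank are set explicitly so it runs outside an initialised process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder dataset just to make the sketch self-contained.
dataset = TensorDataset(torch.randn(1_000, 8))

sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(10):
    sampler.set_epoch(epoch)   # reshuffles deterministically for this epoch
    for (x,) in loader:
        pass                   # training step would go here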
I think this is useful when the dataset is huge, in which case we prefer to split a full epoch into several sub-epochs so that we can adjust the learning rate at the right time (the learning rate is usually adjusted when an epoch finishes); a usage sketch follows the class below. It would be great if you could polish the code and merge it into the PyTorch GitHub repo.
import math
from typing import Iterator, Optional, Sized

import torch
from torch.utils.data import Sampler


class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, samples are drawn from a
    shuffled dataset. If with replacement, the user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        replacement_epoch (bool): if ``True``, each epoch reshuffles independently, so samples may
            repeat across epochs; if ``False``, one global permutation is sliced per epoch, so a
            sample does not reappear until the entire dataset has been iterated. default=``False``
        drop_last (bool): set to ``True`` to drop the last incomplete epoch
            if the dataset size is not divisible by `num_samples`. If ``False`` and
            the size of the dataset is not divisible by `num_samples`, then the last epoch
            will be smaller. (default: ``False``)
        num_samples (int): number of samples to draw per epoch, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 replacement_epoch: bool = False, drop_last: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self.replacement_epoch = replacement_epoch
        self.drop_last = drop_last
        self.epoch = 0
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

        # Use self.num_samples (not the raw argument, which may be None) so the
        # default of len(data_source) is honoured.
        num_epoch = len(self.data_source) / self.num_samples
        self.num_epoch = max(math.floor(num_epoch), 1) if drop_last else math.ceil(num_epoch)
        if (not replacement) and (not replacement_epoch):
            # Precompute one global permutation; each epoch slices a disjoint chunk of it.
            if self.generator is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
                generator = torch.Generator()
                generator.manual_seed(seed)
            else:
                generator = self.generator
            n = len(self.data_source)
            self.randperm_list = torch.randperm(n, generator=generator).tolist()
    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples
    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Sampling with replacement: draw indices in chunks of 32.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            if self.replacement_epoch:
                # No replacement within an epoch, but every epoch reshuffles independently,
                # so samples may repeat across epochs.
                for _ in range(self.num_samples // n):
                    yield from torch.randperm(n, generator=generator).tolist()
                yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
            else:
                # No replacement across epochs: yield the next disjoint slice of the global
                # permutation. The last slice may be shorter when drop_last is False.
                for _ in range(self.num_samples // n):
                    yield from torch.randperm(n, generator=generator).tolist()
                yield from self.randperm_list[self.num_samples * self.epoch:min(self.num_samples * (self.epoch + 1), n)]
    def __len__(self) -> int:
        return self.num_samples
    def set_epoch(self, epoch: int) -> None:
        r"""Sets the epoch for this sampler.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch % self.num_epoch
        if self.epoch == 0:
            # A full pass over the dataset is complete: reshuffle the global permutation
            # before starting the next round of epochs.
            if self.generator is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
                generator = torch.Generator()
                generator.manual_seed(seed)
            else:
                generator = self.generator
            n = len(self.data_source)
            self.randperm_list = torch.randperm(n, generator=generator).tolist()
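For context, here is a rough sketch of how I intend to drive the sampler: split one pass over a large dataset into sub-epochs of num_samples indices each, call set_epoch() before every sub-epoch, and step the learning-rate scheduler at sub-epoch boundaries. The dataset, model, optimizer, and hyper-parameters below are placeholders, not part of the proposal:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data and model just to make the sketch self-contained.
dataset = TensorDataset(torch.randn(10_000, 16), torch.randint(0, 2, (10_000,)))
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# One sub-epoch covers 2_000 of the 10_000 samples, so num_epoch is 5 and
# every sample is seen exactly once across 5 consecutive sub-epochs.
sampler = RandomSampler(dataset, num_samples=2_000)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

criterion = nn.CrossEntropyLoss()
for sub_epoch in range(20):
    sampler.set_epoch(sub_epoch)   # advance to the next disjoint slice
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    scheduler.step()               # adjust the learning rate per sub-epoch

This way the learning rate can be adjusted every num_samples samples instead of only after a full pass over a huge dataset.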