New subset every epoch

I have a very big dataset, and I would like to use a different random subset of 1000 samples for each epoch. Is there any way to do this using Dataset and DataLoader?
I would like something like torch.utils.data.RandomSampler, but sampling without replacement.

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=False,  # shuffle must be False when a sampler is passed
    num_workers=1,
    pin_memory=True,
    drop_last=True,
    sampler=SubsetRandomSampler(
        # 1000 random indices; note that torch.randint draws with replacement,
        # so duplicate indices are possible
        torch.randint(high=len(train_dataset), size=(1000,))
    ),
)

Edit: What I really want is a maximum number of samples per epoch, so I have no problem keeping all samples in the dataset and randomly selecting 1000 of them each epoch.

Edit2: I came up with the following:

import torch
from torch.utils.data import Sampler


class RandomSampler(Sampler):
    def __init__(self, data_source, num_samples=None):
        self.data_source = data_source
        self._num_samples = num_samples

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(self.num_samples)
            )

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        # take the first num_samples indices of a fresh permutation,
        # so each epoch draws a new subset without replacement
        return iter(torch.randperm(n, dtype=torch.int64)[: self.num_samples].tolist())

    def __len__(self):
        return self.num_samples

I edited the default RandomSampler in order to be able to sample without replacement, but I don’t know if this is the correct solution.
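
For completeness, a minimal sketch of how such a sampler could be plugged into a DataLoader (train_dataset stands in for the real dataset):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=1,
    pin_memory=True,
    drop_last=True,
    sampler=RandomSampler(train_dataset, num_samples=1000),
)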


I think your first approach should also work, as SubsetRandomSampler doesn’t use replacement. Or did you see any issues using it?

But SubsetRandomSampler uses the same subset of samples for all epochs, and what I would like is a new random sample from the whole dataset for every epoch.


Ah yeah, sorry for not mentioning it, but you could recreate the DataLoader with a new sampler in each epoch, which should be cheap if you are lazily loading the data.
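
Something like this minimal sketch would do it (the TensorDataset here is just a stand-in for the real, lazily loading dataset, and num_epochs is a placeholder):

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

train_dataset = TensorDataset(torch.randn(100_000, 8))  # stand-in dataset

num_epochs = 5
for epoch in range(num_epochs):
    # draw 1000 unique indices for this epoch (randperm samples without replacement)
    indices = torch.randperm(len(train_dataset))[:1000].tolist()
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        sampler=SubsetRandomSampler(indices),
        drop_last=True,
    )
    for batch in train_loader:
        pass  # training step goes here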

Yes, I thought of doing that, but I wanted a DataLoader that could do it without being recreated. What do you think of Edit2 in the first post? Would it do the trick?
I tried it, and I think it is working correctly, but I would like to be sure that it is genuinely random.

Your approach looks correct. To verify it, I would suggest printing the index in Dataset.__getitem__ for a couple of epochs and making sure that you are seeing a variety of indices.
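
For example, a quick check along these lines could confirm it (IndexEchoDataset is just a toy helper written for this check, and RandomSampler is the custom sampler from the first post):

import torch
from torch.utils.data import DataLoader, Dataset

class IndexEchoDataset(Dataset):
    # toy dataset that records every index requested via __getitem__
    def __init__(self, size):
        self.size = size
        self.seen = []

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        self.seen.append(index)
        return torch.tensor(index, dtype=torch.float32)

dataset = IndexEchoDataset(10_000)
# with the default num_workers=0 the indices are recorded in the main process
loader = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset, num_samples=1000))

for epoch in range(3):
    dataset.seen.clear()
    for _ in loader:
        pass
    print(f"epoch {epoch}: {len(set(dataset.seen))} unique out of {len(dataset.seen)} indices")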


One can also just break after a given amount of data per epoch with shuffle on.
This way one gets a random subsample of the whole dataset per epoch. For example, if you originally have 1000 iterations per epoch, you can break after 500 and run twice the number of original epochs.

for epoch in range(num_epochs * 2):
    runs = 0
    for item in tqdm(dataloader):
        runs = runs + 1
        if runs > 500:
            break

This approach only creates a unique subset within the epoch. However, __iter__ is called again in the next epoch and creates a new subset that is not disjoint from the previous one.

In essence, this solution is the same as setting sampler=torch.utils.data.RandomSampler(dataset, replacement=False) in the DataLoader.

This won’t work, since the above line reshuffles the data under the hood.
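
A quick way to see that (the TensorDataset is just a stand-in; any map-style dataset works): the first 1000 indices of two separate passes over the sampler generally overlap, so the per-epoch subsets are not disjoint.

import itertools
import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.randn(10_000, 8))
sampler = RandomSampler(dataset, replacement=False)

# two separate iterations correspond to two epochs in a DataLoader
epoch_a = set(itertools.islice(iter(sampler), 1000))
epoch_b = set(itertools.islice(iter(sampler), 1000))
print(len(epoch_a & epoch_b))  # usually > 0: the subsets overlap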

Is there any way to create a new subset without replacement every epoch (i.e., once a sample is used, it will not show up in the following epochs until the entire dataset has been iterated)?

Inspired by the set_epoch() of DistributedSampler, I added set_epoch() to the current RandomSampler, so that in each epoch it will use a different subset of the dataset. Can you review the code to see if it’s correct?

I think this is useful when the dataset is huge, in which case we prefer to split a full epoch into several sub-epochs so that we can adjust the learning rate at the right time (we usually adjust the learning rate when an epoch finishes). It would be great if you could polish the code and commit it to the PyTorch GitHub.

import math
from typing import Iterator, Optional, Sized

import torch
from torch.utils.data import Sampler


class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        replacement_epoch (bool): if ``True``, a fresh permutation is drawn every epoch, so
            samples may repeat across epochs; if ``False``, successive epochs walk through one
            fixed permutation, so a sample does not reappear until the whole dataset has been
            used. default=``False``
        drop_last (bool): set to ``True`` to drop the last incomplete epoch
            if the dataset size is not divisible by ``num_samples``. If ``False`` and
            the dataset size is not divisible by ``num_samples``, the last epoch
            will be smaller. (default: ``False``)
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 replacement_epoch: bool = False, drop_last: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self.replacement_epoch = replacement_epoch
        self.drop_last = drop_last
        self.epoch = 0
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

        # use the num_samples property so a num_samples of None falls back to the dataset size
        num_epoch = len(self.data_source) / self.num_samples
        self.num_epoch = max(math.floor(num_epoch), 1) if drop_last else math.ceil(num_epoch)
        if (not replacement) and (not replacement_epoch):
            if self.generator is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
                generator = torch.Generator()
                generator.manual_seed(seed)
            else:
                generator = self.generator
            n = len(self.data_source)
            self.randperm_list = torch.randperm(n, generator=generator).tolist()
        

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # draw indices with replacement, in chunks of 32 (mirrors the stock implementation)
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            if self.replacement_epoch:
                # draw a fresh permutation every epoch: no repeats within an epoch,
                # but successive epochs may overlap
                for _ in range(self.num_samples // n):
                    yield from torch.randperm(n, generator=generator).tolist()
                yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
            else:
                # walk through one fixed permutation across epochs, so a sample does not
                # reappear until the whole dataset has been used (see set_epoch below)
                for _ in range(self.num_samples // n):
                    yield from torch.randperm(n, generator=generator).tolist()
                yield from self.randperm_list[self.num_samples * self.epoch:min(self.num_samples * (self.epoch + 1), n)]

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch % self.num_epoch
        if self.epoch == 0:
            # the fixed permutation has been fully consumed; draw a new one for the next cycle
            if self.generator is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
                generator = torch.Generator()
                generator.manual_seed(seed)
            else:
                generator = self.generator
            n = len(self.data_source)
            self.randperm_list = torch.randperm(n, generator=generator).tolist()
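
A minimal usage sketch, assuming the class above (the TensorDataset is only a stand-in dataset; the learning rate can then be adjusted between sub-epochs as needed):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10_000, 8))
sampler = RandomSampler(dataset, num_samples=1000)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(20):
    sampler.set_epoch(epoch)  # advance the window through the fixed permutation
    for batch in loader:
        pass  # training step goes here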