How to avoid pitfalls when creating a new dataloader each epoch

I conduct experiments that rely on sampling data from a probability distribution that changes each epoch. Therefore, I use a custom DataSampler.

Based on what I found online, the only way to change the probability distribution each epoch is to create a new Dataloader each epoch as well.

However, while this works as expected, it seems to negatively impact performance: there is a significant delay between the end of one epoch and the start of the next. In part, this is expected, because I draw samples from a large probability distribution (ca. 3M samples).

Currently, I use persistent_workers=True in the DataLoader, because I thought workers might persist between two successive DataLoader instances, but that does not seem to be the case.

Some context regarding the dataset: the dataset itself is in n5 format and I load it from disk using zarr.


  1. Can I draw from a distribution that changes each epoch without creating a new dataloader every time?
  2. How do the DataLoader arguments persistent_workers and pin_memory affect my use case?
  3. Are there any tricks to improve performance, or anything else to watch out for when frequently recreating DataLoaders, that I might be unaware of given the explanation above?

  1. You could try to implement a custom sampler that uses an internal epoch counter to change the distribution. This example shows a very simple approach:
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class MySampler(Sampler):
    def __init__(self, num_samples):
        self._num_samples = num_samples
        self.epoch = 0

    @property
    def num_samples(self):
        return self._num_samples

    def __iter__(self):
        n = self.num_samples
        if self.epoch == 0:
            print("creating samples for epoch 0")
            yield from torch.randperm(n).tolist()
        else:
            print("creating samples for epoch > 0")
            # always draw index 1 to make the changed distribution obvious
            yield from torch.ones(n, dtype=torch.long).tolist()
        # only reached once all indices of this epoch have been consumed
        self.epoch += 1

    def __len__(self) -> int:
        return self.num_samples

sampler = MySampler(10)
dataset = TensorDataset(torch.randn(10, 1))
loader = DataLoader(dataset, sampler=sampler, batch_size=5)

for data in loader:
    print(data)
# creating samples for epoch 0
# [tensor([[-0.6384],
#         [-0.8028],
#         [-0.8999],
#         [-0.3053],
#         [ 0.1909]])]
# [tensor([[ 0.3961],
#         [-0.9236],
#         [-0.8927],
#         [-0.0187],
#         [-0.3265]])]

for data in loader:
    print(data)
# creating samples for epoch > 0
# [tensor([[0.1909],
#         [0.1909],
#         [0.1909],
#         [0.1909],
#         [0.1909]])]
# [tensor([[0.1909],
#         [0.1909],
#         [0.1909],
#         [0.1909],
#         [0.1909]])]

This might be useful as a template. Otherwise, you might need to recreate the DataLoader in each epoch.

  2. pin_memory=True makes the DataLoader place the loaded and processed batches in page-locked (pinned) host memory, which speeds up the host-to-device transfer and allows you to overlap it with other work (e.g. via .to(device, non_blocking=True)). With persistent_workers=True, the DataLoader workers are kept alive between epochs, which can avoid the initial “epoch warmup”. If you recreate the DataLoader in each epoch, you should not use persistent workers, as the old workers would keep running in the background.

  3. Creating the DataLoader itself should be cheap, assuming you are not preloading data in your Dataset. However, recreating it would still trigger the warmup, which you could try to hide by creating the iterator manually. This code snippet shows how the workers prefetch the samples: creating the iterator via iter(loader) directly starts prefetching.
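A minimal sketch of points 2 and 3, under stated assumptions: the dataset opens its storage lazily (nothing is preloaded in `__init__`), a single DataLoader is created once and reused, and the iterator is created explicitly at the start of each epoch. The names `LazyArrayDataset`, `make_loader` and `open_array` are hypothetical; in the original setup `open_array` would open the n5 container via zarr, which is not shown here, so an in-memory tensor stands in for it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class LazyArrayDataset(Dataset):
    """Hypothetical dataset that defers opening its storage until the first read.

    Nothing is loaded in __init__, so building the dataset (and a DataLoader
    around it) stays cheap. With num_workers > 0, each worker process opens
    its own handle on its first __getitem__ call.
    """

    def __init__(self, open_array, length):
        self._open_array = open_array  # callable returning an indexable array
        self._length = length
        self._array = None  # opened lazily on first access

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if self._array is None:
            self._array = self._open_array()
        return torch.as_tensor(self._array[idx])


def make_loader(dataset, sampler=None, num_workers=0):
    # Hypothetical helper; batch_size=4 is an arbitrary choice for the demo.
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=4,
        num_workers=num_workers,
        # Pinned memory only helps host-to-GPU copies, so enable it only with CUDA.
        pin_memory=torch.cuda.is_available(),
        # Keep workers alive across epochs; this option requires num_workers > 0.
        persistent_workers=num_workers > 0,
    )


def open_array():
    # Stand-in for opening the n5 container with zarr.
    return torch.arange(20, dtype=torch.float32).unsqueeze(1)


dataset = LazyArrayDataset(open_array, length=20)
loader = make_loader(dataset)  # created once, reused for every epoch
device = "cuda" if torch.cuda.is_available() else "cpu"

batches_seen = 0
for epoch in range(2):
    # Update any sampler state before this line: with num_workers > 0,
    # iter() starts (or, with persistent workers, resets) the workers and
    # immediately queues the first batches of indices for prefetching.
    it = iter(loader)
    for batch in it:
        # non_blocking=True only overlaps the copy when the source is pinned.
        batch = batch.to(device, non_blocking=True)
        batches_seen += 1
```

To benefit from the warmup hiding in practice, `num_workers` has to be greater than 0; the sketch uses 0 only so it runs anywhere. The ordering matters: because the first indices are drawn when the iterator is created, any change to the sampler's distribution should happen before `iter(loader)`, not after.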


Thank you very much for addressing my questions! I will try setting persistent_workers=False. I noticed larger-than-usual memory consumption; maybe some dangling workers are the cause.

Thanks also for the code snippet. Unfortunately the distribution the sampler draws from changes based on the losses of the previous epoch, so I cannot pre-define its behaviour.
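Even when the distribution depends on the previous epoch's losses, the loader does not necessarily have to be recreated: the DataLoader asks its sampler for a fresh iterator at the start of every epoch, and the sampler lives in the main process even when `num_workers > 0`. A minimal sketch, assuming per-sample losses can be recorded and that sampling with replacement is acceptable; the `LossWeightedSampler` class and its `set_weights` method are hypothetical, and the absolute value of the input stands in for the real loss.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class LossWeightedSampler(Sampler):
    """Hypothetical sampler whose weights are replaced between epochs.

    The DataLoader calls iter(sampler) at the start of every epoch, so weights
    set via set_weights() are used in the next epoch without a new loader.
    """

    def __init__(self, num_samples, generator=None):
        self.num_samples = num_samples
        self.weights = torch.ones(num_samples, dtype=torch.double)  # uniform first epoch
        self.generator = generator

    def set_weights(self, weights):
        weights = torch.as_tensor(weights, dtype=torch.double)
        if weights.shape != (self.num_samples,):
            raise ValueError("expected exactly one weight per sample")
        self.weights = weights

    def __iter__(self):
        # One vectorized draw (with replacement) for the whole epoch.
        idx = torch.multinomial(
            self.weights, self.num_samples, replacement=True, generator=self.generator
        )
        yield from idx.tolist()

    def __len__(self):
        return self.num_samples


n = 8
features = torch.randn(n, 1)
# Return each sample's index as well, so per-sample losses can be recorded.
dataset = TensorDataset(features, torch.arange(n))
sampler = LossWeightedSampler(n, generator=torch.Generator().manual_seed(0))
loader = DataLoader(dataset, sampler=sampler, batch_size=4)  # created once

per_sample_loss = torch.zeros(n)
for epoch in range(3):
    for x, idx in loader:
        # Stand-in for a real per-sample loss, e.g. a criterion with reduction="none".
        loss = x.squeeze(1).abs()
        per_sample_loss[idx] = loss.detach()
    # Sample proportionally to the last recorded loss; the floor keeps every
    # sample reachable (samples never drawn so far still have a loss of 0).
    sampler.set_weights(per_sample_loss.clamp_min(1e-3))
```

torch.utils.data.WeightedRandomSampler also draws its indices with torch.multinomial; the custom class is used here only to make the per-epoch weight update explicit. Unlike the epoch counter in the template above, which is only incremented once all indices of an epoch have been consumed, the update happens explicitly in the training loop, so an interrupted epoch cannot leave the sampler in an unexpected state.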

Update, in case it helps others:
On a larger-scale training run (still on one machine) I had serious memory (RAM) issues, as the process kept allocating more and more memory. Setting persistent_workers=False alone did not solve the issue. However, after also setting pin_memory=True, the memory leak no longer occurred.