DataLoader workers generate the same random augmentations

The question I’m about to ask is probably not PyTorch-specific, but I encountered it in the context of the PyTorch DataLoader.

How do you properly add random perturbations when data is loaded and augmented by several processes?

Let me show with a simple example that this is not a trivial question. I have two files:

  • augmentations.py:
import numpy as np
import os

class RandomAugmentation:
    def __call__(self, obj):
        # Draw a random value that a real augmentation would use to
        # perturb `obj`; here we only print it together with the PID.
        perturbation = np.random.randint(10)
        print(os.getpid(), perturbation)
        return obj
  • main.py:
from time import sleep

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize

from augmentations import RandomAugmentation

PATH = ...  # path to an image folder dataset

transform = Compose([
    RandomAugmentation(),
    Resize((16, 16)),
    ToTensor()
])
ds = ImageFolder(PATH, transform=transform)
dl = DataLoader(ds, batch_size=2, num_workers=3)

for epoch_nr in range(2):
    for batch in dl:
        break  # fetching a single batch per epoch is enough for this demonstration
    sleep(1)
    print('-' * 80)

In main.py we generate batches of data as we would in a regular image-recognition task. RandomAugmentation prints the output of the RNG that would be used for data augmentation in a real task.

On my machine the output from running main.py was:

20909 6
20909 8
20908 6
20910 6
20908 8
20910 8
20909 7
20909 6
20908 7
20908 6
20910 7
20908 0
20910 6
20908 1
--------------------------------------------------------------------------------
20952 6
20953 6
20952 8
20952 7
20953 8
20952 6
20954 6
20952 0
20952 1
20953 7
20954 8
20953 6
20954 7
20954 6
--------------------------------------------------------------------------------

Clearly, the i-th perturbation added by one worker is exactly the same as the i-th perturbation added by any other worker (in this example 6, 8, 7, 6, and so on). What’s even worse, in the next epoch the data loader is essentially reset and generates exactly the same (perturbed) data as before, which completely defeats the purpose of data augmentation.
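To see that this is not PyTorch-specific, here is a minimal sketch of my own (plain multiprocessing, nothing from the DataLoader internals) showing that forked children inherit numpy’s global RNG state on platforms that fork; as far as I understand, the same thing happens to the DataLoader workers:

import multiprocessing as mp
import os

import numpy as np


def draw():
    # Each forked child starts from a copy of the parent's numpy RNG
    # state, so all children print the same value.
    print(os.getpid(), np.random.randint(10))


if __name__ == "__main__":
    ctx = mp.get_context("fork")
    procs = [ctx.Process(target=draw) for _ in range(3)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()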

For the time being I came up with the following workaround:

  • augmentations.py:

import numpy as np
from time import time
import os

# Process-local RNG: (re)created lazily the first time get_rng() is
# called in a given process.
RNG_PID = None
RNG = None


def get_rng():
    global RNG
    global RNG_PID
    if os.getpid() != RNG_PID:
        # First call in this process: seed a fresh RNG from the current
        # time and the PID, so every worker gets a different stream.
        RNG = np.random.RandomState([int(time()), os.getpid()])
        RNG_PID = os.getpid()
        print("Initialize RNG", int(time()), os.getpid())
    return RNG

class RandomAugmentation:
    def __call__(self, obj):
        rng = get_rng()
        perturbation = rng.randint(10)
        print(os.getpid(), perturbation)
        return obj

However, it doesn’t seem like a very elegant solution to me.

It would be great to hear how you deal with this matter 🙂


It’s a known issue that occurs when you use other libraries (numpy in your case) to generate random numbers.
Each worker duplicates numpy’s PRNG state when it is forked, so you’ll see the same numbers.
It’s described in a bit more detail in the FAQ.

You could sample the random numbers with torch.randint instead, or use worker_init_fn to seed numpy in each worker process.
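A minimal sketch of the worker_init_fn approach could look like the following (untested; seed_numpy_worker is just an illustrative name, and ds is the dataset from your main.py above):

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_numpy_worker(worker_id):
    # Each worker already receives a distinct torch seed (and a fresh one
    # every epoch), so numpy's seed can simply be derived from it.
    # np.random.seed only accepts values below 2**32.
    np.random.seed(torch.initial_seed() % 2 ** 32)


dl = DataLoader(ds, batch_size=2, num_workers=3,
                worker_init_fn=seed_numpy_worker)

With this in place RandomAugmentation can keep using np.random directly, and every worker draws a different sequence in every epoch.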


I think it would be good to clarify this in the official tutorial https://pytorch.org/tutorials/beginner/data_loading_tutorial.html since it also uses numpy’s random module.

That might indeed be a good idea. Would you be interested in implementing it? 🙂

OK, I just submitted PR https://github.com/pytorch/tutorials/pull/1121 with the changes: I added a note after the Transforms example clarifying the potential issue.
