The question I’m about to ask is probably not PyTorch-specific, but I encountered it in the context of the PyTorch DataLoader.
How do you properly add random perturbations when data is loaded and augmented by several processes?
Let me show with a simple example that this is not a trivial question. I have two files:
- augmentations.py:
import numpy as np
import os

class RandomAugmentation:
    def __call__(self, obj):
        perturbation = np.random.randint(10)
        print(os.getpid(), perturbation)
        return obj
- main.py:
import numpy as np
from time import sleep
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Resize

from augmentations import RandomAugmentation

PATH = ...

transform = Compose([
    RandomAugmentation(),
    Resize((16, 16)),
    ToTensor()
])

ds = ImageFolder(PATH, transform=transform)
dl = DataLoader(ds, batch_size=2, num_workers=3)

for epoch_nr in range(2):
    for batch in dl:
        break
    sleep(1)
    print('-' * 80)
In main.py we generate batches of data as we would in a regular image recognition task. RandomAugmentation prints the RNG output that, in a real task, would be used for data augmentation.
On my machine the output from running main.py was:
20909 6
20909 8
20908 6
20910 6
20908 8
20910 8
20909 7
20909 6
20908 7
20908 6
20910 7
20908 0
20910 6
20908 1
--------------------------------------------------------------------------------
20952 6
20953 6
20952 8
20952 7
20953 8
20952 6
20954 6
20952 0
20952 1
20953 7
20954 8
20953 6
20954 7
20954 6
--------------------------------------------------------------------------------
Clearly, the i-th perturbation added by one worker is exactly the same as the i-th perturbation added by any other worker (in this example 6, 8, 7, 6, and so on). What’s even worse, in the next epoch the data loader is effectively reset and generates exactly the same (perturbed) data as before, which completely defeats the purpose of data augmentation.
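The likely cause: on Linux, DataLoader workers are started with the fork start method by default, and each forked child inherits an identical copy of NumPy's global RNG state. The effect can be reproduced without PyTorch at all; the following is a minimal sketch (the `draw` helper and the three explicit processes are just for illustration):

```python
import multiprocessing as mp
import numpy as np

def draw(q):
    # Each forked child inherits the parent's global NumPy RNG state,
    # so the first draw is identical in every child.
    q.put((mp.current_process().pid, np.random.randint(10)))

ctx = mp.get_context("fork")  # mimic DataLoader's default start method on Linux
q = ctx.Queue()
procs = [ctx.Process(target=draw, args=(q,)) for _ in range(3)]
for p in procs:
    p.start()
for p in procs:
    p.join()

results = [q.get() for _ in range(3)]
pids = {pid for pid, _ in results}
values = {v for _, v in results}
# Three distinct PIDs, yet `values` contains a single element:
# every child produced the same "random" number.
print(results)
```

This matches the observed output above: each worker starts from the same seed, and restarting the workers for the next epoch resets them to that seed again.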
For the time being I came up with the following workaround:
- augmentations.py:
import numpy as np
from time import time
import os

RNG_PID = None
RNG = None

def get_rng():
    global RNG
    global RNG_PID
    if os.getpid() != RNG_PID:
        RNG = np.random.RandomState([int(time()), os.getpid()])
        RNG_PID = os.getpid()
        print("Initialize RNG", int(time()), os.getpid())
    return RNG

class RandomAugmentation:
    def __call__(self, obj):
        rng = get_rng()
        perturbation = rng.randint(10)
        print(os.getpid(), perturbation)
        return obj
However, it doesn’t seem like a very elegant solution to me.
It would be great to hear how you deal with this matter.
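For comparison, DataLoader also accepts a worker_init_fn argument that is called once in every worker process, which seems like a natural place to reseed NumPy. A sketch of that approach, assuming the usual fork-based workers (the `seed_worker` name and the base-seed arithmetic are my own, not a PyTorch API):

```python
import numpy as np

def seed_worker(worker_id, base_seed=0):
    # Give every worker its own NumPy seed, derived from a per-run base seed.
    # NumPy seeds must fit in 32 bits, hence the modulo.
    np.random.seed((base_seed + worker_id) % 2**32)

# Hypothetical usage with a DataLoader (requires torch; torch.initial_seed()
# is re-randomized for each epoch's workers, so epochs differ as well):
#
# dl = DataLoader(ds, batch_size=2, num_workers=3,
#                 worker_init_fn=lambda wid: seed_worker(wid, torch.initial_seed() % 2**32))
```

The same seed reproduces the same stream, and different worker ids produce different streams, which is exactly what the augmentation needs.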