Large differences in training runs depending on worker number

I ran the same script on identical data multiple times, started immediately one after another.

The only difference is the number of workers used, i.e.

gray = 0 workers
pink = 1 worker
blue = 2 workers
green = 4 workers
orange = 8 workers

I have enabled determinism and set random seeds everywhere, except for seeding the workers.

The result is remarkable. Can anyone explain where these vast differences are coming from?


Anyone have an idea? I am experiencing this all the time.

PAGE 1:

@ptrblck, sorry for tagging you out of context. I think this workaround may be helpful to add to the PyTorch documentation as an illustration of how to obtain reproducible results when using different numbers of workers (k>=0). The reproducibility here is not in the sense of obtaining the same splits (that is already handled by fixing the seeds), but in the sense of the random operations the workers apply to samples. At this point, in PyTorch 1.0.0 (and most likely in PyTorch 1.0.1), it is not possible to produce exactly the same results. This post provides a simple workaround to deal with this. The main idea is to break the dependency between the state of the worker's random generator and the sample.

Hello,

Please read the entire post before applying it. You can jump to the last paragraph for the conclusion.

Sorry for the long answer. I am interested in reproducibility. It bugs me when code is not reproducible.

Thank you for pointing out this behavior.
Do not worry, this behavior is normal. The difference in your final results is not huge (the scale is misleading; the max difference is ~4%). However, this may indicate that the method you are training is sensitive to variations (randomness), i.e., an unstable method. I wonder why the curves do not have the same length on the x-axis?
From a research perspective, this is really bad news, since one expects that using many workers should ONLY speed up computation and should not affect the outcome of the experimented algorithm in any way. The PyTorch dev team is doing a great job. We will see below how to deal with this impractical behavior (if one cares about reproducibility).

The following findings are based on some of my experiments following this thread. Feel free to correct me.

Here are the main keys:

a. A worker is dedicated to processing a minibatch.
b. If you create k workers and there are b minibatches, the workers will split the b minibatches equally (if possible), so each worker will process b/k minibatches.
c. Workers are created (forked) at each epoch.
d. The function worker_init_fn() is called ONLY ONCE when initializing the worker (at the forking moment, inside torch.utils.data.dataloader._worker_loop(), which is the job run by the worker), and it is the place to seed whatever modules you want. If you don't use worker_init_fn(), each worker will be seeded with base_seed + i, where i is the worker id in [0, k-1]. So, the first issue you see here is that the seed of a worker depends on its id: if you have k=3 or k=10 workers, the workers will be seeded differently. This is a problem (fixed below; see the sketch right after this list). We do not want the randomness in the workers to depend on how many workers we use.
e. A worker processes minibatches sequentially.

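To make point d concrete, here is a minimal sketch of a custom worker_init_fn (the name my_worker_init_fn and the print are only illustrative, not from this post). It shows that inside a worker the default torch seed is base_seed + worker_id, and where you would seed the other modules (numpy, random) if your __getitem__ relies on them:

import random

import numpy as np
import torch


def my_worker_init_fn(worker_id):
    # Called ONCE per worker, right after the fork (point d above).
    # Inside the worker, torch.initial_seed() equals base_seed + worker_id,
    # so the default torch seeding depends on the worker id (and hence on k).
    seed = torch.initial_seed() % 2 ** 32
    print("worker {} seeded with {}".format(worker_id, seed))
    # torch is already seeded by the dataloader; numpy and random are NOT,
    # so seed them here if your __getitem__ relies on them.
    np.random.seed(seed)
    random.seed(seed)


# Usage: DataLoader(dataset, num_workers=4, worker_init_fn=my_worker_init_fn)
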
  1. Why is there this difference in the results?
    At the moment you create (fork) the workers (processes), each one inherits the same generator state (of all the modules) from the main process. Let's call this state S0. For simplicity, let us assume that we have b=7 minibatches, each with exactly 4 samples, and consider the following two cases (a tiny sketch of how a generator's state evolves follows after this list):
    1.1 k=2: this means that the first worker will process 4 minibatches and the second one 3 minibatches. Now, consider the sample i (this index does not change, because it is the index of the sample in the dataset) in the minibatch j (this index will change because the minibatches change: sample i may end up in any minibatch j). This sample will be processed by worker 0 with the state S0mij: the worker has already processed m samples before arriving at this sample, and that processing has changed the state of its internal generator to S0mij. I think you can now see where I am going with this when we change the number of workers.
    1.2 k=4: in this case, each worker will process 2 minibatches, except the last one, which will process only one. Now, all the workers start with the same state S0. Assume that the minibatch splitting is fixed (which is the case in practice if you fix the seeds). To show the issue, consider that our sample i is located in the minibatch j=3. This means that in this case it is most likely processed by a worker other than worker 0 (I am not sure exactly how the minibatches are split between the workers, but I assume that once a worker finishes a minibatch, it starts the next unprocessed one; it does not matter at this point, we just assume that the sample will be processed by another worker). Assume it is worker 1 that processes this sample. Now, the ISSUE is that this worker will cross s samples before arriving at the sample i in the minibatch j. In other words, this worker will have the internal state S0sij, which is different from S0mij. Therefore, the obtained results will certainly be different from the previous case (k=2). That is why you obtained different results.

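To make the notion of generator "state" concrete, here is a tiny sketch (mine, not from the post): the number a sample's transform obtains depends only on how the generator was seeded and on how many numbers were already drawn, which is exactly why the same sample gets different random values depending on which worker processes it and how much work that worker has already done.

import torch

torch.manual_seed(0)                   # generator is in state S0.
a = torch.rand(1)                      # drawing a number advances the state.
b = torch.rand(1)                      # second draw: b differs from a.

torch.manual_seed(0)                   # reset back to S0.
assert torch.equal(torch.rand(1), a)   # same state => same number.

torch.manual_seed(0)                   # back to S0 again, but this time ...
_ = torch.rand(1)                      # ... one draw is consumed first (like a worker
c = torch.rand(1)                      # that already processed other samples).
assert torch.equal(c, b)               # so the "same" draw now yields b, not a.
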
Imagine that your worker needs to randomly crop windows from an image. In this case, the locations of the cropped windows will be different when you use a different number of workers (I verified this by experimenting on real code using PyTorch 1.0.0. I checked the code in PyTorch 1.0.1 and it seems similar, so I expect the same behavior).

Note:

  • num_workers=k where k>0: you will fork k NEW processes. When k=1, you fork one NEW process aside from the main process.
  • num_workers=0: the data is loaded by the main process.
  • In terms of speed, num_workers=0 and num_workers=1 are expected to be similar, since your data is loaded by one process only.

Here is a synthetic code example to show this issue.

import sys

import numpy as np

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Dataset


print("Pytorch: {}".format(torch.__version__))
print("Python: {}".format(sys.version))

x = np.arange(30)


def transform(v):
    return torch.rand(1).item()


class TestDataset(Dataset):
    def __len__(self):
        return 30

    def __getitem__(self, index):
        return transform(index), x[index]


dataset = TestDataset()
seed = 0
torch.manual_seed(seed)  # for reproducibility for the same run.
num_workers = 4  # number of workers
print("num_workers {}".format(num_workers))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=num_workers, drop_last=True)
for i in range(4):
    print("epoch {}".format(i + 1))
    iteration = 0
    for batch in loader:
        iteration += 1
        print("batch:", iteration, batch)

k=4:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 4
epoch 1
batch: 1 [tensor([0.7821, 0.0536, 0.9888, 0.1949], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.6938, 0.2980, 0.1669, 0.2847], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.6540, 0.2994, 0.2798, 0.5160], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.1343, 0.1992, 0.8930, 0.5375], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.5242, 0.1987, 0.5094, 0.7166], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.9157, 0.3889, 0.0907, 0.7386], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.3290, 0.1022, 0.0958, 0.8926], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
epoch 2
batch: 1 [tensor([0.0529, 0.4672, 0.8940, 0.4728], dtype=torch.float64), tensor([ 0, 26, 27,  4])]
batch: 2 [tensor([0.7667, 0.4851, 0.6022, 0.8482], dtype=torch.float64), tensor([11, 16, 22,  2])]
batch: 3 [tensor([0.8490, 0.9919, 0.7204, 0.3633], dtype=torch.float64), tensor([29,  5, 19, 12])]
batch: 4 [tensor([0.8680, 0.9830, 0.0914, 0.4104], dtype=torch.float64), tensor([17, 13,  3, 14])]
batch: 5 [tensor([0.3402, 0.5704, 0.2629, 0.5613], dtype=torch.float64), tensor([21, 20, 23,  7])]
batch: 6 [tensor([0.6301, 0.2657, 0.2147, 0.9602], dtype=torch.float64), tensor([ 9,  8,  6, 24])]
batch: 7 [tensor([0.1428, 0.6246, 0.5488, 0.6210], dtype=torch.float64), tensor([28,  1, 15, 18])]
epoch 3
batch: 1 [tensor([0.4653, 0.5769, 0.5840, 0.4835], dtype=torch.float64), tensor([25, 11, 16, 28])]
batch: 2 [tensor([0.9935, 0.0852, 0.0667, 0.5538], dtype=torch.float64), tensor([ 2, 17,  1,  9])]
batch: 3 [tensor([0.0709, 0.5544, 0.0166, 0.3393], dtype=torch.float64), tensor([22, 19, 18,  5])]
batch: 4 [tensor([0.8908, 0.7556, 0.4431, 0.2334], dtype=torch.float64), tensor([15,  0, 23, 20])]
batch: 5 [tensor([0.3239, 0.8920, 0.1505, 0.7293], dtype=torch.float64), tensor([26, 27,  4, 21])]
batch: 6 [tensor([0.2746, 0.8609, 0.2783, 0.7886], dtype=torch.float64), tensor([29, 14,  7,  3])]
batch: 7 [tensor([0.1664, 0.0153, 0.3084, 0.7510], dtype=torch.float64), tensor([13,  8, 12, 24])]
epoch 4
batch: 1 [tensor([0.6723, 0.7798, 0.2435, 0.6807], dtype=torch.float64), tensor([26, 17, 21, 29])]
batch: 2 [tensor([0.1031, 0.4875, 0.1857, 0.9544], dtype=torch.float64), tensor([24, 10,  4, 20])]
batch: 3 [tensor([0.1638, 0.3149, 0.1199, 0.1525], dtype=torch.float64), tensor([23, 11,  7,  5])]
batch: 4 [tensor([0.1885, 0.2642, 0.8981, 0.3819], dtype=torch.float64), tensor([ 1, 28, 18,  9])]
batch: 5 [tensor([0.2904, 0.1978, 0.6254, 0.6065], dtype=torch.float64), tensor([14, 27,  0, 19])]
batch: 6 [tensor([0.3774, 0.5173, 0.7019, 0.4410], dtype=torch.float64), tensor([ 6, 13, 16, 25])]
batch: 7 [tensor([0.8413, 0.0668, 0.7935, 0.5762], dtype=torch.float64), tensor([15,  8,  3,  2])]

k=2:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 2
epoch 1
batch: 1 [tensor([0.7821, 0.0536, 0.9888, 0.1949], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.6938, 0.2980, 0.1669, 0.2847], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.5242, 0.1987, 0.5094, 0.7166], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.9157, 0.3889, 0.0907, 0.7386], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.5961, 0.8303, 0.0050, 0.6031], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.1832, 0.6373, 0.9856, 0.6443], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.3745, 0.8614, 0.2829, 0.6196], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
epoch 2
batch: 1 [tensor([0.0529, 0.4672, 0.8940, 0.4728], dtype=torch.float64), tensor([ 0, 26, 27,  4])]
batch: 2 [tensor([0.7667, 0.4851, 0.6022, 0.8482], dtype=torch.float64), tensor([11, 16, 22,  2])]
batch: 3 [tensor([0.3402, 0.5704, 0.2629, 0.5613], dtype=torch.float64), tensor([29,  5, 19, 12])]
batch: 4 [tensor([0.6301, 0.2657, 0.2147, 0.9602], dtype=torch.float64), tensor([17, 13,  3, 14])]
batch: 5 [tensor([0.3510, 0.6561, 0.7647, 0.4369], dtype=torch.float64), tensor([21, 20, 23,  7])]
batch: 6 [tensor([0.6252, 0.4360, 0.2570, 0.9621], dtype=torch.float64), tensor([ 9,  8,  6, 24])]
batch: 7 [tensor([0.8529, 0.1900, 0.8626, 0.0526], dtype=torch.float64), tensor([28,  1, 15, 18])]
epoch 3
batch: 1 [tensor([0.4653, 0.5769, 0.5840, 0.4835], dtype=torch.float64), tensor([25, 11, 16, 28])]
batch: 2 [tensor([0.9935, 0.0852, 0.0667, 0.5538], dtype=torch.float64), tensor([ 2, 17,  1,  9])]
batch: 3 [tensor([0.3239, 0.8920, 0.1505, 0.7293], dtype=torch.float64), tensor([22, 19, 18,  5])]
batch: 4 [tensor([0.2746, 0.8609, 0.2783, 0.7886], dtype=torch.float64), tensor([15,  0, 23, 20])]
batch: 5 [tensor([0.4510, 0.3487, 0.2083, 0.4145], dtype=torch.float64), tensor([26, 27,  4, 21])]
batch: 6 [tensor([0.4875, 0.8162, 0.1422, 0.3001], dtype=torch.float64), tensor([29, 14,  7,  3])]
batch: 7 [tensor([0.0964, 0.8291, 0.8070, 0.5823], dtype=torch.float64), tensor([13,  8, 12, 24])]
epoch 4
batch: 1 [tensor([0.6723, 0.7798, 0.2435, 0.6807], dtype=torch.float64), tensor([26, 17, 21, 29])]
batch: 2 [tensor([0.1031, 0.4875, 0.1857, 0.9544], dtype=torch.float64), tensor([24, 10,  4, 20])]
batch: 3 [tensor([0.2904, 0.1978, 0.6254, 0.6065], dtype=torch.float64), tensor([23, 11,  7,  5])]
batch: 4 [tensor([0.3774, 0.5173, 0.7019, 0.4410], dtype=torch.float64), tensor([ 1, 28, 18,  9])]
batch: 5 [tensor([0.8905, 0.4018, 0.8140, 0.2980], dtype=torch.float64), tensor([14, 27,  0, 19])]
batch: 6 [tensor([0.2089, 0.5475, 0.8681, 0.2968], dtype=torch.float64), tensor([ 6, 13, 16, 25])]
batch: 7 [tensor([0.0501, 0.3870, 0.3744, 0.0609], dtype=torch.float64), tensor([15,  8,  3,  2])]
  • You notice that the splits are the same.
  • You also notice that the first two minibatches have the same generated numbers independently of k, EVERY EPOCH. WHY? Because for both k=2 and k=4, each of those first minibatches is processed by a worker that starts from the main process state S0 and has not consumed any random numbers yet. The following minibatches are processed differently because the worker handling them has a different history depending on k: in this example, minibatch 3 will (most likely) be processed by worker 0 when k=2, and by a worker other than 0 (most likely) when k=4. This uncertainty is due to the racing between the processes and to how the minibatches are split between workers.
  2. (The sense of your question): Is there a correlation between good/bad results and the number of workers?
    From my understanding, absolutely not. What you obtained is random, in the sense that if you repeat your experiment many, many times, you will end up with uniform results: every number of workers may lead equally to good or bad results. You can repeat your experiment (if possible) a couple of times and see that the pattern you obtained in this figure will not be repeated, since it is totally random. (You can show us the obtained figures.)

  3. How to obtain the same results when using different numbers of workers (k>0):
    As I mentioned, the PyTorch dev team is doing a great job, and I think this may not be a priority, even though it is very important in terms of reproducibility. The user can apply a simple workaround to fix this issue. The punch line: we assign to each sample within the dataset its own random seed, which we use to reinitialize the internal random state of the worker, in order to break the dependency between the number of workers and the state of the worker when it processes the sample at hand. The random seeds change every epoch. Here is a gist of the code:

import random
import sys

import numpy as np
import torch

from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Dataset

print("Pytorch: {}".format(torch.__version__))
print("Python: {}".format(sys.version))

np.random.seed(0)

x = np.arange(30)


def _init_fn(worker_id):
    # Not necessary since we will reinitialize the internal generator state of the worker at EVERY SAMPLE!
    pass


def transform(v):
    return torch.rand(1).item()


class TestDataset(Dataset):
    def __init__(self):
        super(TestDataset, self).__init__()
        self.seeds = None
        self.set_up_new_seeds()  # set up seeds for the initialization.

    def set_up_new_seeds(self):
        self.seeds = self.get_new_seeds()

    def get_new_seeds(self):
        return np.random.randint(0, 100, len(self))

    def __len__(self):
        return 30

    def __getitem__(self, index):
        # Set the seed for this sample: Seed ALL THE MODULES within the worker that need seeding.
        # In this example: we seed only torch. If you use numpy or other modules to load your samples, you need to
        # seed them as well in this place.
        seed = self.seeds[index]
        torch.manual_seed(seed)

        return transform(index), x[index]


dataset = TestDataset()
seed = 0
torch.manual_seed(seed)  # for reproducibility for the same run.
num_workers = 2
print("num_workers {}".format(num_workers))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=num_workers, drop_last=True, worker_init_fn=_init_fn)
for i in range(4):
    print("epoch {}".format(i + 1))
    iteration = 0
    # Initialize the seeds before creating the workers:
    dataset.set_up_new_seeds()

    for batch in loader:
        iteration += 1
        print("batch:", iteration, batch)
    print("Seeds at epoch {}: {}".format(iteration, dataset.seeds))

Here are the results with k=2, k=4:

k=4:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 4
epoch 1
batch: 1 [tensor([0.8718, 0.4978, 0.7638, 0.7456], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.9686, 0.9686, 0.7731, 0.4963], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.4578, 0.0530, 0.0492, 0.4283], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.6275, 0.6487, 0.9731, 0.8757], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.6558, 0.4963, 0.2298, 0.3615], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.8357, 0.5695, 0.2621, 0.1033], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.1226, 0.5019, 0.8757, 0.0036], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [47 64 82 99 88 49 29 19 19 14 39 32 65  9 57 32 31 74 23 35 75 55 28 34
  0  0 36 53  5 38]
epoch 2
batch: 1 [tensor([0.4342, 0.0530, 0.0043, 0.4670], dtype=torch.float64), tensor([ 0, 26, 27,  4])]
batch: 2 [tensor([0.1490, 0.5695, 0.6487, 0.5596], dtype=torch.float64), tensor([11, 16, 22,  2])]
batch: 3 [tensor([0.7162, 0.2298, 0.4657, 0.6611], dtype=torch.float64), tensor([29,  5, 19, 12])]
batch: 4 [tensor([0.1033, 0.4578, 0.8823, 0.4569], dtype=torch.float64), tensor([17, 13,  3, 14])]
batch: 5 [tensor([0.7865, 0.8823, 0.3991, 0.9731], dtype=torch.float64), tensor([21, 20, 23,  7])]
batch: 6 [tensor([0.3615, 0.2364, 0.7576, 0.5722], dtype=torch.float64), tensor([ 9,  8,  6, 24])]
batch: 7 [tensor([0.5724, 0.1710, 0.4963, 0.7456], dtype=torch.float64), tensor([28,  1, 15, 18])]
Seeds at epoch 7: [17 79  4 42 58 31  1 65 41 57 35 11 46 82 91  0 14 99 53 12 42 84 75 68
  6 68 47  3 76 52]
epoch 3
batch: 1 [tensor([0.4604, 0.8398, 0.8398, 0.4736], dtype=torch.float64), tensor([25, 11, 16, 28])]
batch: 2 [tensor([0.5615, 0.1628, 0.2973, 0.4775], dtype=torch.float64), tensor([ 2, 17,  1,  9])]
batch: 3 [tensor([0.4775, 0.6180, 0.4963, 0.4283], dtype=torch.float64), tensor([22, 19, 18,  5])]
batch: 4 [tensor([0.5737, 0.3344, 0.2267, 0.4978], dtype=torch.float64), tensor([15,  0, 23, 20])]
batch: 5 [tensor([0.8823, 0.2919, 0.4670, 0.8718], dtype=torch.float64), tensor([26, 27,  4, 21])]
batch: 6 [tensor([0.5286, 0.0492, 0.0918, 0.1033], dtype=torch.float64), tensor([29, 14,  7,  3])]
batch: 7 [tensor([0.2621, 0.8157, 0.2364, 0.0043], dtype=torch.float64), tensor([13,  8, 12, 24])]
Seeds at epoch 7: [78 15 20 99 58 23 79 13 85 48 49 69 41 35 64 95 69 94  0 50 36 34 48 93
  3 98 42 77 21 73]
epoch 4
batch: 1 [tensor([0.4540, 0.6180, 0.4670, 0.9872], dtype=torch.float64), tensor([26, 17, 21, 29])]
batch: 2 [tensor([0.4581, 0.1628, 0.4283, 0.2364], dtype=torch.float64), tensor([24, 10,  4, 20])]
batch: 3 [tensor([0.4978, 0.7380, 0.4604, 0.5530], dtype=torch.float64), tensor([23, 11,  7,  5])]
batch: 4 [tensor([0.4581, 0.6147, 0.9847, 0.2621], dtype=torch.float64), tensor([ 1, 28, 18,  9])]
batch: 5 [tensor([0.1033, 0.1490, 0.4963, 0.5695], dtype=torch.float64), tensor([14, 27,  0, 19])]
batch: 6 [tensor([0.6147, 0.6611, 0.9873, 0.8674], dtype=torch.float64), tensor([ 6, 13, 16, 25])]
batch: 7 [tensor([0.5615, 0.6412, 0.4670, 0.4540], dtype=torch.float64), tensor([15,  8,  3,  2])]
Seeds at epoch 7: [ 0 10 43 58 23 59  2 98 62 35 94 67 82 46 99 20 81 50 27 14 41 58 65 36
 10 86 43 11  2 51]

k=2:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 2
epoch 1
batch: 1 [tensor([0.8718, 0.4978, 0.7638, 0.7456], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.9686, 0.9686, 0.7731, 0.4963], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.4578, 0.0530, 0.0492, 0.4283], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.6275, 0.6487, 0.9731, 0.8757], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.6558, 0.4963, 0.2298, 0.3615], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.8357, 0.5695, 0.2621, 0.1033], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.1226, 0.5019, 0.8757, 0.0036], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [47 64 82 99 88 49 29 19 19 14 39 32 65  9 57 32 31 74 23 35 75 55 28 34
  0  0 36 53  5 38]
epoch 2
batch: 1 [tensor([0.4342, 0.0530, 0.0043, 0.4670], dtype=torch.float64), tensor([ 0, 26, 27,  4])]
batch: 2 [tensor([0.1490, 0.5695, 0.6487, 0.5596], dtype=torch.float64), tensor([11, 16, 22,  2])]
batch: 3 [tensor([0.7162, 0.2298, 0.4657, 0.6611], dtype=torch.float64), tensor([29,  5, 19, 12])]
batch: 4 [tensor([0.1033, 0.4578, 0.8823, 0.4569], dtype=torch.float64), tensor([17, 13,  3, 14])]
batch: 5 [tensor([0.7865, 0.8823, 0.3991, 0.9731], dtype=torch.float64), tensor([21, 20, 23,  7])]
batch: 6 [tensor([0.3615, 0.2364, 0.7576, 0.5722], dtype=torch.float64), tensor([ 9,  8,  6, 24])]
batch: 7 [tensor([0.5724, 0.1710, 0.4963, 0.7456], dtype=torch.float64), tensor([28,  1, 15, 18])]
Seeds at epoch 7: [17 79  4 42 58 31  1 65 41 57 35 11 46 82 91  0 14 99 53 12 42 84 75 68
  6 68 47  3 76 52]
epoch 3
batch: 1 [tensor([0.4604, 0.8398, 0.8398, 0.4736], dtype=torch.float64), tensor([25, 11, 16, 28])]
batch: 2 [tensor([0.5615, 0.1628, 0.2973, 0.4775], dtype=torch.float64), tensor([ 2, 17,  1,  9])]
batch: 3 [tensor([0.4775, 0.6180, 0.4963, 0.4283], dtype=torch.float64), tensor([22, 19, 18,  5])]
batch: 4 [tensor([0.5737, 0.3344, 0.2267, 0.4978], dtype=torch.float64), tensor([15,  0, 23, 20])]
batch: 5 [tensor([0.8823, 0.2919, 0.4670, 0.8718], dtype=torch.float64), tensor([26, 27,  4, 21])]
batch: 6 [tensor([0.5286, 0.0492, 0.0918, 0.1033], dtype=torch.float64), tensor([29, 14,  7,  3])]
batch: 7 [tensor([0.2621, 0.8157, 0.2364, 0.0043], dtype=torch.float64), tensor([13,  8, 12, 24])]
Seeds at epoch 7: [78 15 20 99 58 23 79 13 85 48 49 69 41 35 64 95 69 94  0 50 36 34 48 93
  3 98 42 77 21 73]
epoch 4
batch: 1 [tensor([0.4540, 0.6180, 0.4670, 0.9872], dtype=torch.float64), tensor([26, 17, 21, 29])]
batch: 2 [tensor([0.4581, 0.1628, 0.4283, 0.2364], dtype=torch.float64), tensor([24, 10,  4, 20])]
batch: 3 [tensor([0.4978, 0.7380, 0.4604, 0.5530], dtype=torch.float64), tensor([23, 11,  7,  5])]
batch: 4 [tensor([0.4581, 0.6147, 0.9847, 0.2621], dtype=torch.float64), tensor([ 1, 28, 18,  9])]
batch: 5 [tensor([0.1033, 0.1490, 0.4963, 0.5695], dtype=torch.float64), tensor([14, 27,  0, 19])]
batch: 6 [tensor([0.6147, 0.6611, 0.9873, 0.8674], dtype=torch.float64), tensor([ 6, 13, 16, 25])]
batch: 7 [tensor([0.5615, 0.6412, 0.4670, 0.4540], dtype=torch.float64), tensor([15,  8,  3,  2])]
Seeds at epoch 7: [ 0 10 43 58 23 59  2 98 62 35 94 67 82 46 99 20 81 50 27 14 41 58 65 36
 10 86 43 11  2 51]

Why do we need to change the seeds every epoch? It is not necessary, but we do it to break any cycling. For instance, if you use gradient descent in training, it is better to throw some randomness into every epoch. Changing the seeds is one way to do that.

This code does not produce the same results when k=0. Only the first epoch is the same (because the internal state of the workers and of the main process is the same; after the first epoch, the workers' internal state is reinitialized, while the main process keeps going).

k=0:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 0
epoch 1
batch: 1 [tensor([0.8718, 0.4978, 0.7638, 0.7456], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.9686, 0.9686, 0.7731, 0.4963], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.4578, 0.0530, 0.0492, 0.4283], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.6275, 0.6487, 0.9731, 0.8757], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.6558, 0.4963, 0.2298, 0.3615], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.8357, 0.5695, 0.2621, 0.1033], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.1226, 0.5019, 0.8757, 0.0036], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [47 64 82 99 88 49 29 19 19 14 39 32 65  9 57 32 31 74 23 35 75 55 28 34
  0  0 36 53  5 38]
epoch 2
batch: 1 [tensor([0.5724, 0.4963, 0.6611, 0.0530], dtype=torch.float64), tensor([28, 15, 12, 26])]
batch: 2 [tensor([0.1490, 0.2364, 0.8823, 0.3991], dtype=torch.float64), tensor([11,  8,  3, 23])]
batch: 3 [tensor([0.0043, 0.4670, 0.3615, 0.4578], dtype=torch.float64), tensor([27,  4,  9, 13])]
batch: 4 [tensor([0.3991, 0.4342, 0.7576, 0.2621], dtype=torch.float64), tensor([25,  0,  6, 10])]
batch: 5 [tensor([0.5695, 0.5596, 0.7456, 0.1710], dtype=torch.float64), tensor([16,  2, 18,  1])]
batch: 6 [tensor([0.4569, 0.5722, 0.8823, 0.2298], dtype=torch.float64), tensor([14, 24, 20,  5])]
batch: 7 [tensor([0.6487, 0.9731, 0.1033, 0.4657], dtype=torch.float64), tensor([22,  7, 17, 19])]
Seeds at epoch 7: [17 79  4 42 58 31  1 65 41 57 35 11 46 82 91  0 14 99 53 12 42 84 75 68
  6 68 47  3 76 52]
epoch 3
batch: 1 [tensor([0.2919, 0.4775, 0.8398, 0.4963], dtype=torch.float64), tensor([27,  9, 16, 18])]
batch: 2 [tensor([0.8718, 0.1033, 0.1710, 0.4736], dtype=torch.float64), tensor([21,  3,  6, 28])]
batch: 3 [tensor([0.6180, 0.0469, 0.5737, 0.0043], dtype=torch.float64), tensor([19, 10, 15, 24])]
batch: 4 [tensor([0.2973, 0.0492, 0.4604, 0.0918], dtype=torch.float64), tensor([ 1, 14, 25,  7])]
batch: 5 [tensor([0.8398, 0.5615, 0.8157, 0.4670], dtype=torch.float64), tensor([11,  2,  8,  4])]
batch: 6 [tensor([0.1628, 0.4978, 0.4283, 0.3344], dtype=torch.float64), tensor([17, 20,  5,  0])]
batch: 7 [tensor([0.5286, 0.2364, 0.8823, 0.2621], dtype=torch.float64), tensor([29, 12, 26, 13])]
Seeds at epoch 7: [78 15 20 99 58 23 79 13 85 48 49 69 41 35 64 95 69 94  0 50 36 34 48 93
  3 98 42 77 21 73]
epoch 4
batch: 1 [tensor([0.1490, 0.1628, 0.9731, 0.4581], dtype=torch.float64), tensor([27, 10, 22, 24])]
batch: 2 [tensor([0.6412, 0.4978, 0.4540, 0.6180], dtype=torch.float64), tensor([ 8, 23,  2, 17])]
batch: 3 [tensor([0.4963, 0.4578, 0.6147, 0.4581], dtype=torch.float64), tensor([ 0, 12,  6,  1])]
batch: 4 [tensor([0.8674, 0.4283, 0.4540, 0.2364], dtype=torch.float64), tensor([25,  4, 26, 20])]
batch: 5 [tensor([0.2621, 0.5530, 0.7380, 0.6147], dtype=torch.float64), tensor([ 9,  5, 11, 28])]
batch: 6 [tensor([0.6611, 0.4604, 0.5615, 0.5695], dtype=torch.float64), tensor([13,  7, 15, 19])]
batch: 7 [tensor([0.9847, 0.4670, 0.9872, 0.9873], dtype=torch.float64), tensor([18,  3, 29, 16])]
Seeds at epoch 7: [ 0 10 43 58 23 59  2 98 62 35 94 67 82 46 99 20 81 50 27 14 41 58 65 36
 10 86 43 11  2 51]

See next post: PAGE 2 (maximum number of characters: 32000).


PAGE 2

  4. OK, I see what you did there. But how do we obtain the same results when using different numbers of workers (k>=0)?
    As we saw earlier, the internal states of the main process and of the workers evolve differently. If one wants to control the state of the main process, we need to reseed it every epoch as well. The main process may do many things and change many of its modules' internal states, so we need to reseed all the modules that introduce randomness. Here is the code:
import random
import sys

import numpy as np
import torch

from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Dataset

print("Pytorch: {}".format(torch.__version__))
print("Python: {}".format(sys.version))

np.random.seed(0)

x = np.arange(30)


def _init_fn(worker_id):
    # Not necessary since we will reinitialize the internal generator state of the worker at EVERY SAMPLE!
    pass


def transform(v):
    return torch.rand(1).item()


class TestDataset(Dataset):
    def __init__(self):
        super(TestDataset, self).__init__()
        self.seeds = None
        self.set_up_new_seeds()  # set up seeds for the initialization.

    def set_up_new_seeds(self):
        self.seeds = self.get_new_seeds()

    def get_new_seeds(self):
        return np.random.randint(0, 100, len(self))

    def __len__(self):
        return 30

    def __getitem__(self, index):
        # Set the seed for this sample: Seed ALL THE MODULES within the worker that need seeding.
        # In this example: we seed only torch. If you use numpy or other modules to load your samples, you need to
        # seed them as well in this place.
        seed = self.seeds[index]
        torch.manual_seed(seed)

        return transform(index), x[index]


dataset = TestDataset()
seed = 0
torch.manual_seed(seed)  # for reproducibility for the same run.
num_workers = 4
print("num_workers {}".format(num_workers))
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=num_workers, drop_last=True, worker_init_fn=_init_fn)
for i in range(4):
    print("epoch {}".format(i + 1))
    iteration = 0
    # Reseed the modules that introduce randomness: for the case num_workers = 0 (the main process).
    np.random.seed(i)
    random.seed(i)
    torch.manual_seed(i)
    # Initialize the seeds before creating the workers:
    dataset.set_up_new_seeds()

    for batch in loader:
        iteration += 1
        print("batch:", iteration, batch)
    print("Seeds at epoch {}: {}".format(iteration, dataset.seeds))

k=4:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 4
epoch 1
batch: 1 [tensor([0.2919, 0.5615, 0.0036, 0.6104], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.4736, 0.4978, 0.7380, 0.6558], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.0492, 0.7196, 0.0530, 0.6611], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.7518, 0.9873, 0.7731, 0.7731], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.4657, 0.3250, 0.0036, 0.4670], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.7911, 0.0036, 0.7731, 0.7380], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.8029, 0.1710, 0.9731, 0.4350], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [44 47 64 67 67  9 83 21 36 87 70 88 88 12 58 65 39 87 46 88 81 37 25 77
 72  9 20 80 69 79]
epoch 2
batch: 1 [tensor([0.0036, 0.1625, 0.8303, 0.6180], dtype=torch.float64), tensor([24,  8,  5, 22])]
batch: 2 [tensor([0.5073, 0.5695, 0.0918, 0.7865], dtype=torch.float64), tensor([11, 21, 29, 17])]
batch: 3 [tensor([0.8674, 0.5615, 0.1628, 0.1710], dtype=torch.float64), tensor([28, 15, 26,  6])]
batch: 4 [tensor([0.0036, 0.9598, 0.1226, 0.3250], dtype=torch.float64), tensor([25, 27, 20,  2])]
batch: 5 [tensor([0.6180, 0.6275, 0.7911, 0.7518], dtype=torch.float64), tensor([14, 19,  0, 13])]
batch: 6 [tensor([0.5724, 0.4657, 0.5722, 0.7576], dtype=torch.float64), tensor([10,  1, 12,  9])]
batch: 7 [tensor([0.3991, 0.0492, 0.6558, 0.1490], dtype=torch.float64), tensor([23,  7,  3, 18])]
Seeds at epoch 7: [37 12 72  9 75  5 79 64 16  1 76 71  6 25 50 20 18 84 11 28 29 14 50 68
 87 87 94 96 86 13]
epoch 3
batch: 1 [tensor([0.6487, 0.7911, 0.7380, 0.2298], dtype=torch.float64), tensor([11, 18, 20, 15])]
batch: 2 [tensor([0.8718, 0.3250, 0.8089, 0.5596], dtype=torch.float64), tensor([ 8,  2, 14, 21])]
batch: 3 [tensor([0.4540, 0.0036, 0.7731, 0.0530], dtype=torch.float64), tensor([ 4, 19, 29, 13])]
batch: 4 [tensor([0.3659, 0.0469, 0.5737, 0.5615], dtype=torch.float64), tensor([ 3,  9, 10, 17])]
batch: 5 [tensor([0.4578, 0.8157, 0.8823, 0.5349], dtype=torch.float64), tensor([ 5, 12, 22,  7])]
batch: 6 [tensor([0.9872, 0.1819, 0.6186, 0.7380], dtype=torch.float64), tensor([23, 16, 25, 27])]
batch: 7 [tensor([0.8398, 0.6487, 0.2973, 0.5019], dtype=torch.float64), tensor([28,  6,  1, 24])]
Seeds at epoch 7: [40 15 72 22 43 82 75  7 34 49 95 75 85 47 63 31 90 20 37 39 67  4 42 51
 38 33 58 67 69 88]
epoch 4
batch: 1 [tensor([0.7638, 0.4736, 0.9531, 0.2267], dtype=torch.float64), tensor([ 7, 10,  2, 15])]
batch: 2 [tensor([0.4976, 0.4736, 0.9873, 0.5019], dtype=torch.float64), tensor([22,  5, 19, 11])]
batch: 3 [tensor([0.0036, 0.3250, 0.4581, 0.0043], dtype=torch.float64), tensor([16,  3,  9,  1])]
batch: 4 [tensor([0.9598, 0.3659, 0.5615, 0.8089], dtype=torch.float64), tensor([12, 21, 13, 24])]
batch: 5 [tensor([0.4963, 0.9872, 0.8313, 0.7576], dtype=torch.float64), tensor([ 4, 27, 18, 26])]
batch: 6 [tensor([0.7196, 0.2364, 0.1819, 0.0236], dtype=torch.float64), tensor([14,  8, 20, 25])]
batch: 7 [tensor([0.8398, 0.7644, 0.1819, 0.6147], dtype=torch.float64), tensor([29,  0, 28, 23])]
Seeds at epoch 7: [24  3 56 72  0 21 19 74 41 10 21 38 96 20 44 93 39 14 26 81 90 22 66  2
 63 60  1 51 90 69]

k=2:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 2
epoch 1
batch: 1 [tensor([0.2919, 0.5615, 0.0036, 0.6104], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.4736, 0.4978, 0.7380, 0.6558], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.0492, 0.7196, 0.0530, 0.6611], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.7518, 0.9873, 0.7731, 0.7731], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.4657, 0.3250, 0.0036, 0.4670], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.7911, 0.0036, 0.7731, 0.7380], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.8029, 0.1710, 0.9731, 0.4350], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [44 47 64 67 67  9 83 21 36 87 70 88 88 12 58 65 39 87 46 88 81 37 25 77
 72  9 20 80 69 79]
epoch 2
batch: 1 [tensor([0.0036, 0.1625, 0.8303, 0.6180], dtype=torch.float64), tensor([24,  8,  5, 22])]
batch: 2 [tensor([0.5073, 0.5695, 0.0918, 0.7865], dtype=torch.float64), tensor([11, 21, 29, 17])]
batch: 3 [tensor([0.8674, 0.5615, 0.1628, 0.1710], dtype=torch.float64), tensor([28, 15, 26,  6])]
batch: 4 [tensor([0.0036, 0.9598, 0.1226, 0.3250], dtype=torch.float64), tensor([25, 27, 20,  2])]
batch: 5 [tensor([0.6180, 0.6275, 0.7911, 0.7518], dtype=torch.float64), tensor([14, 19,  0, 13])]
batch: 6 [tensor([0.5724, 0.4657, 0.5722, 0.7576], dtype=torch.float64), tensor([10,  1, 12,  9])]
batch: 7 [tensor([0.3991, 0.0492, 0.6558, 0.1490], dtype=torch.float64), tensor([23,  7,  3, 18])]
Seeds at epoch 7: [37 12 72  9 75  5 79 64 16  1 76 71  6 25 50 20 18 84 11 28 29 14 50 68
 87 87 94 96 86 13]
epoch 3
batch: 1 [tensor([0.6487, 0.7911, 0.7380, 0.2298], dtype=torch.float64), tensor([11, 18, 20, 15])]
batch: 2 [tensor([0.8718, 0.3250, 0.8089, 0.5596], dtype=torch.float64), tensor([ 8,  2, 14, 21])]
batch: 3 [tensor([0.4540, 0.0036, 0.7731, 0.0530], dtype=torch.float64), tensor([ 4, 19, 29, 13])]
batch: 4 [tensor([0.3659, 0.0469, 0.5737, 0.5615], dtype=torch.float64), tensor([ 3,  9, 10, 17])]
batch: 5 [tensor([0.4578, 0.8157, 0.8823, 0.5349], dtype=torch.float64), tensor([ 5, 12, 22,  7])]
batch: 6 [tensor([0.9872, 0.1819, 0.6186, 0.7380], dtype=torch.float64), tensor([23, 16, 25, 27])]
batch: 7 [tensor([0.8398, 0.6487, 0.2973, 0.5019], dtype=torch.float64), tensor([28,  6,  1, 24])]
Seeds at epoch 7: [40 15 72 22 43 82 75  7 34 49 95 75 85 47 63 31 90 20 37 39 67  4 42 51
 38 33 58 67 69 88]
epoch 4
batch: 1 [tensor([0.7638, 0.4736, 0.9531, 0.2267], dtype=torch.float64), tensor([ 7, 10,  2, 15])]
batch: 2 [tensor([0.4976, 0.4736, 0.9873, 0.5019], dtype=torch.float64), tensor([22,  5, 19, 11])]
batch: 3 [tensor([0.0036, 0.3250, 0.4581, 0.0043], dtype=torch.float64), tensor([16,  3,  9,  1])]
batch: 4 [tensor([0.9598, 0.3659, 0.5615, 0.8089], dtype=torch.float64), tensor([12, 21, 13, 24])]
batch: 5 [tensor([0.4963, 0.9872, 0.8313, 0.7576], dtype=torch.float64), tensor([ 4, 27, 18, 26])]
batch: 6 [tensor([0.7196, 0.2364, 0.1819, 0.0236], dtype=torch.float64), tensor([14,  8, 20, 25])]
batch: 7 [tensor([0.8398, 0.7644, 0.1819, 0.6147], dtype=torch.float64), tensor([29,  0, 28, 23])]
Seeds at epoch 7: [24  3 56 72  0 21 19 74 41 10 21 38 96 20 44 93 39 14 26 81 90 22 66  2
 63 60  1 51 90 69]

k=0:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 0
epoch 1
batch: 1 [tensor([0.2919, 0.5615, 0.0036, 0.6104], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.4736, 0.4978, 0.7380, 0.6558], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.0492, 0.7196, 0.0530, 0.6611], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.7518, 0.9873, 0.7731, 0.7731], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.4657, 0.3250, 0.0036, 0.4670], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.7911, 0.0036, 0.7731, 0.7380], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.8029, 0.1710, 0.9731, 0.4350], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [44 47 64 67 67  9 83 21 36 87 70 88 88 12 58 65 39 87 46 88 81 37 25 77
 72  9 20 80 69 79]
epoch 2
batch: 1 [tensor([0.0036, 0.1625, 0.8303, 0.6180], dtype=torch.float64), tensor([24,  8,  5, 22])]
batch: 2 [tensor([0.5073, 0.5695, 0.0918, 0.7865], dtype=torch.float64), tensor([11, 21, 29, 17])]
batch: 3 [tensor([0.8674, 0.5615, 0.1628, 0.1710], dtype=torch.float64), tensor([28, 15, 26,  6])]
batch: 4 [tensor([0.0036, 0.9598, 0.1226, 0.3250], dtype=torch.float64), tensor([25, 27, 20,  2])]
batch: 5 [tensor([0.6180, 0.6275, 0.7911, 0.7518], dtype=torch.float64), tensor([14, 19,  0, 13])]
batch: 6 [tensor([0.5724, 0.4657, 0.5722, 0.7576], dtype=torch.float64), tensor([10,  1, 12,  9])]
batch: 7 [tensor([0.3991, 0.0492, 0.6558, 0.1490], dtype=torch.float64), tensor([23,  7,  3, 18])]
Seeds at epoch 7: [37 12 72  9 75  5 79 64 16  1 76 71  6 25 50 20 18 84 11 28 29 14 50 68
 87 87 94 96 86 13]
epoch 3
batch: 1 [tensor([0.6487, 0.7911, 0.7380, 0.2298], dtype=torch.float64), tensor([11, 18, 20, 15])]
batch: 2 [tensor([0.8718, 0.3250, 0.8089, 0.5596], dtype=torch.float64), tensor([ 8,  2, 14, 21])]
batch: 3 [tensor([0.4540, 0.0036, 0.7731, 0.0530], dtype=torch.float64), tensor([ 4, 19, 29, 13])]
batch: 4 [tensor([0.3659, 0.0469, 0.5737, 0.5615], dtype=torch.float64), tensor([ 3,  9, 10, 17])]
batch: 5 [tensor([0.4578, 0.8157, 0.8823, 0.5349], dtype=torch.float64), tensor([ 5, 12, 22,  7])]
batch: 6 [tensor([0.9872, 0.1819, 0.6186, 0.7380], dtype=torch.float64), tensor([23, 16, 25, 27])]
batch: 7 [tensor([0.8398, 0.6487, 0.2973, 0.5019], dtype=torch.float64), tensor([28,  6,  1, 24])]
Seeds at epoch 7: [40 15 72 22 43 82 75  7 34 49 95 75 85 47 63 31 90 20 37 39 67  4 42 51
 38 33 58 67 69 88]
epoch 4
batch: 1 [tensor([0.7638, 0.4736, 0.9531, 0.2267], dtype=torch.float64), tensor([ 7, 10,  2, 15])]
batch: 2 [tensor([0.4976, 0.4736, 0.9873, 0.5019], dtype=torch.float64), tensor([22,  5, 19, 11])]
batch: 3 [tensor([0.0036, 0.3250, 0.4581, 0.0043], dtype=torch.float64), tensor([16,  3,  9,  1])]
batch: 4 [tensor([0.9598, 0.3659, 0.5615, 0.8089], dtype=torch.float64), tensor([12, 21, 13, 24])]
batch: 5 [tensor([0.4963, 0.9872, 0.8313, 0.7576], dtype=torch.float64), tensor([ 4, 27, 18, 26])]
batch: 6 [tensor([0.7196, 0.2364, 0.1819, 0.0236], dtype=torch.float64), tensor([14,  8, 20, 25])]
batch: 7 [tensor([0.8398, 0.7644, 0.1819, 0.6147], dtype=torch.float64), tensor([29,  0, 28, 23])]
Seeds at epoch 7: [24  3 56 72  0 21 19 74 41 10 21 38 96 20 44 93 39 14 26 81 90 22 66  2
 63 60  1 51 90 69]

k=1:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 1
epoch 1
batch: 1 [tensor([0.2919, 0.5615, 0.0036, 0.6104], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.4736, 0.4978, 0.7380, 0.6558], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.0492, 0.7196, 0.0530, 0.6611], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.7518, 0.9873, 0.7731, 0.7731], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.4657, 0.3250, 0.0036, 0.4670], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.7911, 0.0036, 0.7731, 0.7380], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.8029, 0.1710, 0.9731, 0.4350], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [44 47 64 67 67  9 83 21 36 87 70 88 88 12 58 65 39 87 46 88 81 37 25 77
 72  9 20 80 69 79]
epoch 2
batch: 1 [tensor([0.0036, 0.1625, 0.8303, 0.6180], dtype=torch.float64), tensor([24,  8,  5, 22])]
batch: 2 [tensor([0.5073, 0.5695, 0.0918, 0.7865], dtype=torch.float64), tensor([11, 21, 29, 17])]
batch: 3 [tensor([0.8674, 0.5615, 0.1628, 0.1710], dtype=torch.float64), tensor([28, 15, 26,  6])]
batch: 4 [tensor([0.0036, 0.9598, 0.1226, 0.3250], dtype=torch.float64), tensor([25, 27, 20,  2])]
batch: 5 [tensor([0.6180, 0.6275, 0.7911, 0.7518], dtype=torch.float64), tensor([14, 19,  0, 13])]
batch: 6 [tensor([0.5724, 0.4657, 0.5722, 0.7576], dtype=torch.float64), tensor([10,  1, 12,  9])]
batch: 7 [tensor([0.3991, 0.0492, 0.6558, 0.1490], dtype=torch.float64), tensor([23,  7,  3, 18])]
Seeds at epoch 7: [37 12 72  9 75  5 79 64 16  1 76 71  6 25 50 20 18 84 11 28 29 14 50 68
 87 87 94 96 86 13]
epoch 3
batch: 1 [tensor([0.6487, 0.7911, 0.7380, 0.2298], dtype=torch.float64), tensor([11, 18, 20, 15])]
batch: 2 [tensor([0.8718, 0.3250, 0.8089, 0.5596], dtype=torch.float64), tensor([ 8,  2, 14, 21])]
batch: 3 [tensor([0.4540, 0.0036, 0.7731, 0.0530], dtype=torch.float64), tensor([ 4, 19, 29, 13])]
batch: 4 [tensor([0.3659, 0.0469, 0.5737, 0.5615], dtype=torch.float64), tensor([ 3,  9, 10, 17])]
batch: 5 [tensor([0.4578, 0.8157, 0.8823, 0.5349], dtype=torch.float64), tensor([ 5, 12, 22,  7])]
batch: 6 [tensor([0.9872, 0.1819, 0.6186, 0.7380], dtype=torch.float64), tensor([23, 16, 25, 27])]
batch: 7 [tensor([0.8398, 0.6487, 0.2973, 0.5019], dtype=torch.float64), tensor([28,  6,  1, 24])]
Seeds at epoch 7: [40 15 72 22 43 82 75  7 34 49 95 75 85 47 63 31 90 20 37 39 67  4 42 51
 38 33 58 67 69 88]
epoch 4
batch: 1 [tensor([0.7638, 0.4736, 0.9531, 0.2267], dtype=torch.float64), tensor([ 7, 10,  2, 15])]
batch: 2 [tensor([0.4976, 0.4736, 0.9873, 0.5019], dtype=torch.float64), tensor([22,  5, 19, 11])]
batch: 3 [tensor([0.0036, 0.3250, 0.4581, 0.0043], dtype=torch.float64), tensor([16,  3,  9,  1])]
batch: 4 [tensor([0.9598, 0.3659, 0.5615, 0.8089], dtype=torch.float64), tensor([12, 21, 13, 24])]
batch: 5 [tensor([0.4963, 0.9872, 0.8313, 0.7576], dtype=torch.float64), tensor([ 4, 27, 18, 26])]
batch: 6 [tensor([0.7196, 0.2364, 0.1819, 0.0236], dtype=torch.float64), tensor([14,  8, 20, 25])]
batch: 7 [tensor([0.8398, 0.7644, 0.1819, 0.6147], dtype=torch.float64), tensor([29,  0, 28, 23])]
Seeds at epoch 7: [24  3 56 72  0 21 19 74 41 10 21 38 96 20 44 93 39 14 26 81 90 22 66  2
 63 60  1 51 90 69]

k=3:

Pytorch: 1.0.0
Python: 3.7.0 | packaged by conda-forge | (default, Nov 12 2018, 20:15:55)
[GCC 7.3.0]
num_workers 3
epoch 1
batch: 1 [tensor([0.2919, 0.5615, 0.0036, 0.6104], dtype=torch.float64), tensor([23, 26, 17, 27])]
batch: 2 [tensor([0.4736, 0.4978, 0.7380, 0.6558], dtype=torch.float64), tensor([ 7,  8,  4, 25])]
batch: 3 [tensor([0.0492, 0.7196, 0.0530, 0.6611], dtype=torch.float64), tensor([ 2,  0,  1, 18])]
batch: 4 [tensor([0.7518, 0.9873, 0.7731, 0.7731], dtype=torch.float64), tensor([22, 20, 12, 11])]
batch: 5 [tensor([0.4657, 0.3250, 0.0036, 0.4670], dtype=torch.float64), tensor([13, 24, 16, 14])]
batch: 6 [tensor([0.7911, 0.0036, 0.7731, 0.7380], dtype=torch.float64), tensor([21,  9, 19,  3])]
batch: 7 [tensor([0.8029, 0.1710, 0.9731, 0.4350], dtype=torch.float64), tensor([ 6, 29, 15, 10])]
Seeds at epoch 7: [44 47 64 67 67  9 83 21 36 87 70 88 88 12 58 65 39 87 46 88 81 37 25 77
 72  9 20 80 69 79]
epoch 2
batch: 1 [tensor([0.0036, 0.1625, 0.8303, 0.6180], dtype=torch.float64), tensor([24,  8,  5, 22])]
batch: 2 [tensor([0.5073, 0.5695, 0.0918, 0.7865], dtype=torch.float64), tensor([11, 21, 29, 17])]
batch: 3 [tensor([0.8674, 0.5615, 0.1628, 0.1710], dtype=torch.float64), tensor([28, 15, 26,  6])]
batch: 4 [tensor([0.0036, 0.9598, 0.1226, 0.3250], dtype=torch.float64), tensor([25, 27, 20,  2])]
batch: 5 [tensor([0.6180, 0.6275, 0.7911, 0.7518], dtype=torch.float64), tensor([14, 19,  0, 13])]
batch: 6 [tensor([0.5724, 0.4657, 0.5722, 0.7576], dtype=torch.float64), tensor([10,  1, 12,  9])]
batch: 7 [tensor([0.3991, 0.0492, 0.6558, 0.1490], dtype=torch.float64), tensor([23,  7,  3, 18])]
Seeds at epoch 7: [37 12 72  9 75  5 79 64 16  1 76 71  6 25 50 20 18 84 11 28 29 14 50 68
 87 87 94 96 86 13]
epoch 3
batch: 1 [tensor([0.6487, 0.7911, 0.7380, 0.2298], dtype=torch.float64), tensor([11, 18, 20, 15])]
batch: 2 [tensor([0.8718, 0.3250, 0.8089, 0.5596], dtype=torch.float64), tensor([ 8,  2, 14, 21])]
batch: 3 [tensor([0.4540, 0.0036, 0.7731, 0.0530], dtype=torch.float64), tensor([ 4, 19, 29, 13])]
batch: 4 [tensor([0.3659, 0.0469, 0.5737, 0.5615], dtype=torch.float64), tensor([ 3,  9, 10, 17])]
batch: 5 [tensor([0.4578, 0.8157, 0.8823, 0.5349], dtype=torch.float64), tensor([ 5, 12, 22,  7])]
batch: 6 [tensor([0.9872, 0.1819, 0.6186, 0.7380], dtype=torch.float64), tensor([23, 16, 25, 27])]
batch: 7 [tensor([0.8398, 0.6487, 0.2973, 0.5019], dtype=torch.float64), tensor([28,  6,  1, 24])]
Seeds at epoch 7: [40 15 72 22 43 82 75  7 34 49 95 75 85 47 63 31 90 20 37 39 67  4 42 51
 38 33 58 67 69 88]
epoch 4
batch: 1 [tensor([0.7638, 0.4736, 0.9531, 0.2267], dtype=torch.float64), tensor([ 7, 10,  2, 15])]
batch: 2 [tensor([0.4976, 0.4736, 0.9873, 0.5019], dtype=torch.float64), tensor([22,  5, 19, 11])]
batch: 3 [tensor([0.0036, 0.3250, 0.4581, 0.0043], dtype=torch.float64), tensor([16,  3,  9,  1])]
batch: 4 [tensor([0.9598, 0.3659, 0.5615, 0.8089], dtype=torch.float64), tensor([12, 21, 13, 24])]
batch: 5 [tensor([0.4963, 0.9872, 0.8313, 0.7576], dtype=torch.float64), tensor([ 4, 27, 18, 26])]
batch: 6 [tensor([0.7196, 0.2364, 0.1819, 0.0236], dtype=torch.float64), tensor([14,  8, 20, 25])]
batch: 7 [tensor([0.8398, 0.7644, 0.1819, 0.6147], dtype=torch.float64), tensor([29,  0, 28, 23])]
Seeds at epoch 7: [24  3 56 72  0 21 19 74 41 10 21 38 96 20 44 93 39 14 26 81 90 22 66  2
 63 60  1 51 90 69]

I tested this on a real code X (PyTorch 1.0.0) where the workers need to do some random operations on each sample, such as cropping, rotation, and so on.
Repeating the same code X twice (on the same device: same GPU id) with the same number of workers leads to exactly the same results (100% reproducible). However, changing the number of workers leads to entirely different results (normal behavior). Using the workaround mentioned above, the code X remains reproducible when repeated with the same number of workers, and produces exactly the same results whatever the number of workers k>=0. HOWEVER, the example discussed in this post is very basic and shows only the reproducibility of the dataloader. In practice, the main process does many other things that involve randomness (which we didn't show here). Such randomness will cause a problem when k=0 compared to k>0.

Now, we understand that the job of a worker is only to load samples. The rest of the code is run by the main process. The thing that we didn't show in the example above, and which is the case in practice, is the following:

# P0
for i, (data, labels) in enumerate(dataloader):  # calling workers.
  # P1
  # do some random stuff: main process.

In this case, the state of the internal generator of the main process at P1 will depend on k:

  • If k>0: the state at P0 is the same as the state at P1.
  • If k=0: the state at P0 is different from the state at P1, because it is the main process that loads the samples. Therefore, if that loading involves randomness, the state of the main process will have changed after loading the samples.

To avoid this issue, make sure that after calling the workers' code you reset the state of the main process to a fixed state (it does not matter if it changes from one iteration to another). What matters is that the state of the main process at P1 is the same independently of k. Here is a snippet:

# P0
for i, (data, labels) in enumerate(dataloader):  # calling workers.
  # P1: re-seed the state of the main process to something that does not depend on the number k.
  # do some random stuff: main process.

You need to do this re-seeding of the main process state at every access to the dataloader (every call to the workers). In practice, this is mostly done when looping over minibatches (train/evaluation).
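
For completeness, here is a runnable sketch of that pattern (the helper reseed_main_process and the seed formula are my own illustrative choices, not from the post above): the main process is re-seeded at P1 with a value that depends only on the epoch and the iteration, so its state there is identical for any k.

import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def reseed_main_process(epoch, iteration):
    # Derive a seed that depends only on (epoch, iteration), never on
    # num_workers, so the main-process state at P1 is the same for any k.
    seed = epoch * 10000 + iteration
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


dataset = TensorDataset(torch.arange(30).float())
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, drop_last=True)

for epoch in range(2):
    for i, (data,) in enumerate(loader):  # P0: workers (or the main process if k=0) load the samples.
        reseed_main_process(epoch, i)     # P1: re-seed so the state no longer depends on k.
        noise = torch.rand(data.shape)    # any "random stuff" in the main process is now reproducible.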


Thanks for the nice explanation. I am cropping the image before input.

  1. I did not use worker_init_fn=_init_fn.
  2. I pass the epoch as the seed number, as in the code below:

for epoch in range(1, 10):
  # P0
  torch.manual_seed(epoch)
  torch.cuda.manual_seed(epoch)
  for i, (data, labels) in enumerate(dataloader):  # calling workers.
    # P1:
    # do some random stuff: main process.
  3. The random crop is done as:
      x = int(torch.randint(np.maximum(0, new_w - opt.crop_size), (1,)))
      y = int(torch.randint(np.maximum(0, new_h - opt.crop_size), (1,)))

Below are the crop values (xy = the crop position) for num_workers = 0 and 2. For simplicity, only some of the starting values and batches are shown.

# num of worker=2
*** xy is  62 18
*** xy is  2 92
*** xy is  3 96
*** xy is  72 79

(epoch: 1 iters: 8, time: 0.188, data: 0.450) 
(epoch: 1 iters: 16 time: 0.073, data: 0.004)
# num of worker=0
*** xy is  127 9
*** xy is  75 5
*** xy is  79 64
*** xy is  16 1
*** xy is  76 71
*** xy is  109 124
*** xy is  6 25
*** xy is  50 20
(epoch: 1, iters: 8, time: 0.142, data: 0.199) 
*** xy is  126 101
*** xy is  18 84
*** xy is  11 124
*** xy is  106 28
*** xy is  29 14
*** xy is  50 68
*** xy is  87 87
*** xy is  105 113
(epoch: 1, iters: 16, time: 0.057, data: 0.159) 

Below is the PSNR graph for num_workers = 0 and 2.



For num_workers = 0, PSNR and SSIM increase as the loss decreases and the epochs progress, while for num_workers = 2 they fluctuate. I used 80 training images with batch size 8, and 10 test images, for this analysis.

Is it possible for the PSNR to follow a graph like this? As per my understanding, the workers just create the input batches and pass them to the main process; they should not affect the PSNR graph.

I think this state only changes the crop values, or am I missing something?