Reproducibility over multiple GPUs is impossible until the randomness of threads is controlled, and even then

+++++++++++++++++++++++++++++++++++++++++++++++++++++++

                      PAGE 1/5

+++++++++++++++++++++++++++++++++++++++++++++++++++++++

Code tested using: Pytorch (1.0.0)/Python 3.7.0, over K=2 GPUs.

GIST-of-GIST:

  1. Depending on what your code is doing, it is possible to make the code reproducible over multiple GPUs for a constant number of GPUs K (i.e., the results are reproducible only for that specific K).
  2. Controlling the randomness of each thread over each GPU allows the reproducibility of (1) (conceptually, and over synthetic code). However, using real code, I was unable to obtain stable/reproducible results (sometimes I obtain 100% reproducible results, other times not). This must have something to do with the sensitivity of torch.nn.CrossEntropyLoss to the instability/randomness of F.interpolate in my code, as in this issue.
  3. In order to control randomness in threads when using torch.nn.DataParallel, we propose to (see the sketch right after this list):
    a. Use threading.Lock to lock down the random regions within each thread so that they are thread-safe (threads share the random generators).
    b. Re-seed each thread separately. One can go further and pass around a separate PRNG state for each thread.
  4. I spent a lot of time on this issue. I decided to put this matter to bed for now. I hope that the dev-team will equip Pytorch with simpler and more efficient tools for reproducibility over single and multiple GPUs.
  5. Related: 1, 2.
  6. You may consider disabling Cudnn if you are unable to get reproducible results: torch.backends.cudnn.enabled = False.
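
A minimal sketch of 3.a and 3.b, assuming a forward that contains a single random instruction; NoisyModel, rng_lock and base_seed are illustrative names of mine, not a PyTorch API:

import threading

import torch
import torch.nn as nn
from torch.nn import DataParallel


class NoisyModel(nn.Module):
    """Toy module whose forward contains a random instruction (stands in for dropout, etc.)."""

    # Class-level lock: shared by every replica, since the DataParallel threads live in one process.
    rng_lock = threading.Lock()

    def __init__(self, base_seed=0):
        super(NoisyModel, self).__init__()
        self.base_seed = base_seed

    def forward(self, x):
        device_id = x.get_device() if x.is_cuda else 0
        # (3.a) Serialize the random region: only one thread touches the shared generator at a time.
        with NoisyModel.rng_lock:
            # (3.b) Re-seed per device, so the draw no longer depends on which thread ran first.
            torch.manual_seed(self.base_seed + device_id)
            delta = torch.rand(x.size()).to(x.device)
        return x + delta


# Usage: model = DataParallel(NoisyModel(base_seed=0)).to(torch.device("cuda"))
# The outputs no longer depend on the thread scheduling, but they remain tied to K,
# because the way the batch is split into chunks depends on the number of GPUs.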

GIST:

After some digging, I came to this:

  1. If the forward function (or other functions) that you parallelized using torch.nn.DataParallel contains random instructions such as dropout, reproducibility over multiple GPUs in Pytorch (1.0.0)/Python 3.7.0 is impossible for an arbitrary number of GPUs K (in the sense that you cannot obtain the same results for K=1, AND K=2, AND K=3, AND K=4, …).
  2. Why? Because of multithreading, which is non-deterministic by definition. torch.nn.parallel.parallel_apply.py uses threads (a simplified sketch of this mechanism follows this list). Threads share memory (which includes the random generator). Do you see the problem now? In our case, the threads truly run in parallel since each one runs over its own GPU device (assuming there are only GPU instructions and no CPU instructions). You cannot know with certainty the order in which the threads execute; it is up to the scheduler of the OS kernel at runtime. Therefore, you cannot determine the order of the threads' calls to the random generator. Each call changes its internal state. As a consequence, you cannot determine at which state of the generator thread i will make its call.
  3. The good news is that you can make your code reproducible for a fixed number of GPUs K, in the sense that the results obtained with K GPUs are reproducible ONLY when using K GPUs.
  4. Reproducibility in Pytorch still needs a lot of work.
  5. Nightly build version 1.2.0.dev20190616 (https://download.pytorch.org/whl/nightly/cu100/torch_nightly-1.2.0.dev20190616-cp37-cp37m-linux_x86_64.whl) seems to have fixed a huge glitch in the rng. (merge)
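
Roughly, here is a simplified sketch of what torch.nn.parallel.parallel_apply does (approximate, not the actual PyTorch source): one thread per replica, all in the same process, so they share the interpreter state, including the global random generators.

import threading

import torch


def parallel_apply_sketch(replicas, inputs, devices):
    results = {}
    lock = threading.Lock()

    def _worker(i, module, inp, device):
        with torch.cuda.device(device):
            out = module(inp)  # the parallelized forward runs HERE, inside a thread
        with lock:
            results[i] = out

    threads = [threading.Thread(target=_worker, args=(i, m, inp, d))
               for i, (m, inp, d) in enumerate(zip(replicas, inputs, devices))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return [results[i] for i in range(len(replicas))]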

I am not an expert in multithreading. I had to spend some time experimenting and testing to confirm the above logic. What is known about multithreading is that threads share the memory, and each thread has its own stack. I am not sure what goes into the stack. Random generators are shared among the threads.

Here is some code to support the above. The batch size is 8; therefore, each GPU will process 4 samples. The code has only one random instruction, which happens in the forward. We perform only one call to the forward. The code consists of taking an input x, then adding a random value to it element-wise.

Code 1: threads share the random generator.

This code has only two possible random output states. The output state depends on which thread, among the two, reaches the random generator FIRST. This race does not depend on you, your code, or how you designed it; it is up to the OS kernel scheduler. On every run, you may get different results depending on the order in which the threads were executed. Since we have only two threads and only one random instruction, one call to the forward function can generate only two outcome states. We show the status of the random generator by displaying its signature before and after calling the random generator. We assume that calling our random generator is thread-safe (meaning that if both threads call the random instruction at the same time, nothing goes wrong).

import random

import numpy as np

import torch
import torch.nn as nn
from torch.nn import DataParallel

def set_seed(seed):
    """
    Seed the random generators of the relevant modules (torch, numpy, random).
    :param seed: int. The seed.
    :return:
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        if x.is_cuda:
            device = x.get_device()
            print("DEVICE: {}".format(x.get_device()))
            print("x.size() = {}".format(x.size()))
            print("x:\n {}".format(x))
        else:
            device = torch.device("cpu")

        prngs0 = torch.random.get_rng_state().type(torch.float).numpy()
        # Summarizing the PRNG state with sum(abs(diff)**2) is better than computing only its sum:
        # two different PRNG states may have the same sum (this can happen when we generate few random numbers).
        print("DEVICE {} PRNG STATUS BEFORE: \n {}".format(torch.cuda.current_device(),
                                                           np.abs(np.diff(prngs0)**2).sum()))

        # Note: torch.rand() draws from the shared global CPU generator; both DataParallel threads advance the same state.
        delta = torch.rand(x.size()).to(device)
        prngs1 = torch.random.get_rng_state().type(torch.float).numpy()
        print("DEVICE {} PRNG STATUS AFTER: \n {}".format(torch.cuda.current_device(),
                                                          np.abs(np.diff(prngs1)**2).sum()))

        print("DEVICE {} DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM(): \n {}".format(
            torch.cuda.current_device(),   np.abs(prngs0 - prngs1).sum()))
        print("Delta:\n {}".format(delta))
        x = x + delta
        return x


if __name__ == "__main__":
    set_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Model()
    model = DataParallel(model)
    model.to(device)

    x = torch.rand(8, 3)
    x = x.to(device)
    print("X in:\n {}".format(x))
    print("X out:\n {}".format(model(x)))

run 1:

X in:
 tensor([[0.4963, 0.7682, 0.0885],
        [0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556],
        [0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939],
        [0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816],
        [0.9152, 0.3971, 0.8742]], device='cuda:0')
DEVICE: 0
x.size() = torch.Size([4, 3])
DEVICE: 1
x.size() = torch.Size([4, 3])
x:
 tensor([[0.0223, 0.1689, 0.2939],
        [0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816],
        [0.9152, 0.3971, 0.8742]], device='cuda:1')
DEVICE 1 PRNG STATUS BEFORE: 
 46790328.0
DEVICE 1 PRNG STATUS AFTER: 
 46787832.0
DEVICE 1 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM(): 
 24.0
Delta:
 tensor([[0.4194, 0.5529, 0.9527],
        [0.0362, 0.1852, 0.3734],
        [0.3051, 0.9320, 0.1759],
        [0.2698, 0.1507, 0.0317]], device='cuda:1')
x:
 tensor([[0.4963, 0.7682, 0.0885],
        [0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556],
        [0.6323, 0.3489, 0.4017]], device='cuda:0')
DEVICE 0 PRNG STATUS BEFORE: 
 46787832.0
DEVICE 0 PRNG STATUS AFTER: 
 46786488.0
DEVICE 0 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM(): 
 24.0
Delta:
 tensor([[0.2081, 0.9298, 0.7231],
        [0.7423, 0.5263, 0.2437],
        [0.5846, 0.0332, 0.1387],
        [0.2422, 0.8155, 0.7932]], device='cuda:0')
X out:
 tensor([[0.7044, 1.6980, 0.8116],
        [0.8744, 0.8337, 0.8777],
        [1.0747, 0.9296, 0.5943],
        [0.8745, 1.1644, 1.1949],
        [0.4417, 0.7218, 1.2466],
        [0.5547, 0.8829, 1.1734],
        [0.4661, 1.2143, 0.8575],
        [1.1850, 0.5478, 0.9059]], device='cuda:0')

run 2:

X in:
 tensor([[0.4963, 0.7682, 0.0885],
        [0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556],
        [0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939],
        [0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816],
        [0.9152, 0.3971, 0.8742]], device='cuda:0')
DEVICE: 0
x.size() = torch.Size([4, 3])
DEVICE: 1
x.size() = torch.Size([4, 3])
x:
 tensor([[0.0223, 0.1689, 0.2939],
        [0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816],
        [0.9152, 0.3971, 0.8742]], device='cuda:1')
x:
 tensor([[0.4963, 0.7682, 0.0885],
        [0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556],
        [0.6323, 0.3489, 0.4017]], device='cuda:0')
DEVICE 0 PRNG STATUS BEFORE: 
 46790328.0
DEVICE 1 PRNG STATUS BEFORE: 
 46790328.0
DEVICE 0 PRNG STATUS AFTER: 
 46786488.0
DEVICE 1 PRNG STATUS AFTER: 
 46786488.0
DEVICE 0 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM(): 
 48.0
DEVICE 1 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM(): 
 48.0
Delta:
 tensor([[0.4194, 0.5529, 0.9527],
        [0.0362, 0.1852, 0.3734],
        [0.3051, 0.9320, 0.1759],
        [0.2698, 0.1507, 0.0317]], device='cuda:0')
Delta:
 tensor([[0.2081, 0.9298, 0.7231],
        [0.7423, 0.5263, 0.2437],
        [0.5846, 0.0332, 0.1387],
        [0.2422, 0.8155, 0.7932]], device='cuda:1')
X out:
 tensor([[0.9157, 1.3211, 1.0412],
        [0.1682, 0.4927, 1.0075],
        [0.7952, 1.8284, 0.6315],
        [0.9021, 0.4996, 0.4334],
        [0.2305, 1.0987, 1.0170],
        [1.2609, 1.2240, 1.0437],
        [0.7456, 0.3154, 0.8203],
        [1.1574, 1.2126, 1.6673]], device='cuda:0')

These are the only possible outcomes. You can see that there are only two unique deltas; each one is generated either by the first or the second thread, depending on which one calls torch.rand() first. The signature of the PRNG state is not perfect, in the sense that two DIFFERENT PRNG states may have the same signature. It seems that the PRNG state is laid out in some specific way with certain properties. I did not spend much time finding a perfectly unique signature.
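
As an aside, a practically collision-free signature can be obtained by hashing the raw bytes of the state; this is only a sketch, and it is not the signature used in the runs above:

import hashlib

import torch


def rng_signature():
    """Hex digest of the current CPU PRNG state; different states give different digests in practice."""
    state = torch.random.get_rng_state()  # ByteTensor holding the generator's state
    return hashlib.sha256(state.numpy().tobytes()).hexdigest()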

Imagine you call the forward function 100 times. Predicting the output is impossible, since the thread ordering is random. Now you understand why it is impossible to obtain reproducible results when using multithreading with a shared random generator.