Code tested using: PyTorch (1.0.0) / Python 3.7.0, over K=2 GPUs.
GIST-of-GIST:
- Depending on what your code is doing, it is possible to make the code reproducible over multiple GPUs for a constant number of GPUs K (i.e., results are reproducible if and only if you keep using K GPUs).
- Controlling the randomness of each thread over each GPU allows reproducibility (1) (conceptually, and over synthetic code). However, using real code, I was unable to obtain stable/reproducible results (sometimes I obtain 100% reproducible results, other times not). This must have something to do with the sensitivity of torch.nn.CrossEntropyLoss to the instability/randomness of F.interpolate in my code, as in this issue.
- In order to control randomness in threads when using torch.nn.DataParallel, we propose to:
  a. Use threading.Lock to lock down the random regions within each thread so that they are thread-safe (threads share random generators).
  b. Re-seed each thread separately. One can go further and pass around a PRNG state for each thread.
- I spent a lot of time on this issue. I decided to put this matter to bed for now. I hope that the dev team will equip PyTorch with simpler and more efficient tools for reproducibility over single and multiple GPUs.
- Related: 1, 2.
- You may consider disabling cuDNN if you are unable to get reproducible results: torch.backends.cudnn.enabled = False.
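The two proposals (a) and (b) above can be sketched on the CPU with plain Python threads. This is only an illustration under assumptions: Python's `random` module stands in for the generator that the DataParallel threads share, and the names `BASE_SEED`, `worker_shared`, `worker_reseeded` and `run` are hypothetical helpers, not part of PyTorch.

```python
import random
import threading

BASE_SEED = 0  # hypothetical base seed shared by all workers
rng_lock = threading.Lock()  # guards the shared global generator


def worker_shared(results, idx, n):
    # Option (a): the lock makes the random region atomic, but the values a
    # thread receives still depend on which thread acquires the lock first.
    with rng_lock:
        results[idx] = [random.random() for _ in range(n)]


def worker_reseeded(results, idx, n):
    # Option (b): each thread owns a generator seeded from (BASE_SEED, idx),
    # so its draws no longer depend on thread scheduling.
    local_rng = random.Random(BASE_SEED + idx)
    results[idx] = [local_rng.random() for _ in range(n)]


def run(worker, num_threads=2, n=3):
    random.seed(BASE_SEED)  # reset the shared generator before each run
    results = [None] * num_threads
    threads = [threading.Thread(target=worker, args=(results, i, n))
               for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    # Per-thread generators: identical results across runs.
    print(run(worker_reseeded) == run(worker_reseeded))
    # Shared generator with a lock: the pool of drawn chunks is the same,
    # but which thread gets which chunk is decided by the scheduler.
    print(sorted(run(worker_shared)) == sorted(run(worker_shared)))
```

The lock alone only prevents two threads from interleaving inside the random region; it is the per-thread seeding that removes the dependence on thread order.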
GIST:
After some digging, I came to this:
- If the forward function (or other functions) that you parallelized using torch.nn.DataParallel contains random instructions such as dropout, reproducibility over multiple GPUs in PyTorch (1.0.0) / Python 3.7.0 is impossible across different numbers of GPUs K (in the sense that you cannot obtain the same results for K=1, AND K=2, AND K=3, AND K=4, …).
- Why? Because of multithreading, where the execution order of the threads is non-deterministic. torch.nn.parallel.parallel_apply.py uses threads. Threads share memory (which includes the random generator). Do you see the problem now? In our case, the threads truly run in parallel since each one runs over a GPU device (assuming there are only GPU instructions and no CPU instructions). You cannot know with certainty the order of execution of the threads; it is up to the scheduler of the OS kernel at runtime. Therefore, you cannot determine the order of the threads' calls to the random generator. Each call changes its internal state. As a consequence, you cannot determine in what state thread i will find the random generator when it calls it.
- The good news is that you can make your code reproducible for a fixed K GPUs, in the sense that the results obtained with K GPUs can be reproduced ONLY when using K GPUs.
- Reproducibility in PyTorch still needs a lot of work.
- Nightly build version 1.2.0.dev20190616 (https://download.pytorch.org/whl/nightly/cu100/torch_nightly-1.2.0.dev20190616-cp37-cp37m-linux_x86_64.whl) seems to have fixed a huge glitch in the RNG. (merge)
I am not an expert in multithreading. I had to spend some time experimenting and testing to confirm the above logic. What is known about threads is that they share memory, and each thread has its own stack. I am not sure what goes into the stack. Random generators are shared among the threads.
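The effect of the call order on a shared generator can be reproduced deterministically, without threads or GPUs. The sketch below is an analogy under assumptions: a `random.Random` instance stands in for the generator shared by the threads, and `forward_pass` and `call_order` are hypothetical names used only for this illustration.

```python
import random


def forward_pass(call_order, n=3, seed=0):
    """Simulate one forward pass in which the threads listed in call_order
    reach a shared generator in that order."""
    shared = random.Random(seed)  # stands in for the generator shared by threads
    deltas = {}
    for thread_id in call_order:
        # Each call advances the shared state, so later callers see other numbers.
        deltas[thread_id] = [shared.random() for _ in range(n)]
    return deltas


if __name__ == "__main__":
    a = forward_pass(call_order=[0, 1])  # thread 0 wins the race
    b = forward_pass(call_order=[1, 0])  # thread 1 wins the race
    # The same two chunks are drawn, but they are assigned to swapped threads.
    print(a[0] == b[1], a[1] == b[0])
    # Thread 0 therefore receives different values in the two scenarios.
    print(a[0] == b[0])
```

The seed is identical in both scenarios; only the order of the calls differs, which is exactly the part the OS scheduler controls.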
Here is some code to support the above. The batch size is 8. Therefore, each GPU will process 4 samples. The code has only one random instruction, which happens in the forward. We perform only one call to the forward. The code takes an input x, then adds a random value to it element-wise.
Code 1: threads share random generator.
This code has two and only two possible random output states. The output state depends on which of the two threads reaches the random generator FIRST. This race does not depend on you, your code, or how you designed it; it is up to the OS kernel scheduler. In every run, you may get different results depending on the order in which the threads were executed. Since we have only two threads and only one random instruction, one call to the forward function can generate only two outcome states. We show the status of the random generator by displaying its signature before and after calling the random generator. We assume that calling our random generator is thread-safe (meaning that if both threads call the random instruction at the same time, nothing goes wrong).
import random

import numpy as np
import torch
import torch.nn as nn
from torch.nn import DataParallel


def set_seed(seed):
    """
    Seed the random modules used in this script.
    :param seed: int. The seed.
    :return:
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        if x.is_cuda:
            device = x.get_device()
            print("DEVICE: {}".format(x.get_device()))
            print("x.size() = {}".format(x.size()))
            print("x:\n {}".format(x))
        else:
            device = torch.device("cpu")
        prngs0 = torch.random.get_rng_state().type(torch.float).numpy()
        # computing sum(abs(diff)**2) is a better way to have a summary of the PRNGs instead of computing only the sum.
        # Two different PRNGs may have the same sum (may happen if we generate few random numbers).
        print("DEVICE {} PRNG STATUS BEFORE: \n {}".format(torch.cuda.current_device(),
                                                           np.abs(np.diff(prngs0)**2).sum()))
        delta = torch.rand(x.size()).to(device)
        prngs1 = torch.random.get_rng_state().type(torch.float).numpy()
        print("DEVICE {} PRNG STATUS AFTER: \n {}".format(torch.cuda.current_device(),
                                                          np.abs(np.diff(prngs1)**2).sum()))
        print("DEVICE {} DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM(): \n {}".format(
            torch.cuda.current_device(), np.abs(prngs0 - prngs1).sum()))
        print("Delta:\n {}".format(delta))
        x = x + delta
        return x


if __name__ == "__main__":
    set_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model()
    model = DataParallel(model)
    model.to(device)
    x = torch.rand(8, 3)
    x = x.to(device)
    print("X in:\n {}".format(x))
    print("X out:\n {}".format(model(x)))
run 1:
X in:
tensor([[0.4963, 0.7682, 0.0885],
[0.1320, 0.3074, 0.6341],
[0.4901, 0.8964, 0.4556],
[0.6323, 0.3489, 0.4017],
[0.0223, 0.1689, 0.2939],
[0.5185, 0.6977, 0.8000],
[0.1610, 0.2823, 0.6816],
[0.9152, 0.3971, 0.8742]], device='cuda:0')
DEVICE: 0
x.size() = torch.Size([4, 3])
DEVICE: 1
x.size() = torch.Size([4, 3])
x:
tensor([[0.0223, 0.1689, 0.2939],
[0.5185, 0.6977, 0.8000],
[0.1610, 0.2823, 0.6816],
[0.9152, 0.3971, 0.8742]], device='cuda:1')
DEVICE 1 PRNG STATUS BEFORE:
46790328.0
DEVICE 1 PRNG STATUS AFTER:
46787832.0
DEVICE 1 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM():
24.0
Delta:
tensor([[0.4194, 0.5529, 0.9527],
[0.0362, 0.1852, 0.3734],
[0.3051, 0.9320, 0.1759],
[0.2698, 0.1507, 0.0317]], device='cuda:1')
x:
tensor([[0.4963, 0.7682, 0.0885],
[0.1320, 0.3074, 0.6341],
[0.4901, 0.8964, 0.4556],
[0.6323, 0.3489, 0.4017]], device='cuda:0')
DEVICE 0 PRNG STATUS BEFORE:
46787832.0
DEVICE 0 PRNG STATUS AFTER:
46786488.0
DEVICE 0 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM():
24.0
Delta:
tensor([[0.2081, 0.9298, 0.7231],
[0.7423, 0.5263, 0.2437],
[0.5846, 0.0332, 0.1387],
[0.2422, 0.8155, 0.7932]], device='cuda:0')
X out:
tensor([[0.7044, 1.6980, 0.8116],
[0.8744, 0.8337, 0.8777],
[1.0747, 0.9296, 0.5943],
[0.8745, 1.1644, 1.1949],
[0.4417, 0.7218, 1.2466],
[0.5547, 0.8829, 1.1734],
[0.4661, 1.2143, 0.8575],
[1.1850, 0.5478, 0.9059]], device='cuda:0')
run 2:
X in:
tensor([[0.4963, 0.7682, 0.0885],
[0.1320, 0.3074, 0.6341],
[0.4901, 0.8964, 0.4556],
[0.6323, 0.3489, 0.4017],
[0.0223, 0.1689, 0.2939],
[0.5185, 0.6977, 0.8000],
[0.1610, 0.2823, 0.6816],
[0.9152, 0.3971, 0.8742]], device='cuda:0')
DEVICE: 0
x.size() = torch.Size([4, 3])
DEVICE: 1
x.size() = torch.Size([4, 3])
x:
tensor([[0.0223, 0.1689, 0.2939],
[0.5185, 0.6977, 0.8000],
[0.1610, 0.2823, 0.6816],
[0.9152, 0.3971, 0.8742]], device='cuda:1')
x:
tensor([[0.4963, 0.7682, 0.0885],
[0.1320, 0.3074, 0.6341],
[0.4901, 0.8964, 0.4556],
[0.6323, 0.3489, 0.4017]], device='cuda:0')
DEVICE 0 PRNG STATUS BEFORE:
46790328.0
DEVICE 1 PRNG STATUS BEFORE:
46790328.0
DEVICE 0 PRNG STATUS AFTER:
46786488.0
DEVICE 1 PRNG STATUS AFTER:
46786488.0
DEVICE 0 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM():
48.0
DEVICE 1 DIFF PRNG STATUS ABS(BEFORE - AFTER).SUM():
48.0
Delta:
tensor([[0.4194, 0.5529, 0.9527],
[0.0362, 0.1852, 0.3734],
[0.3051, 0.9320, 0.1759],
[0.2698, 0.1507, 0.0317]], device='cuda:0')
Delta:
tensor([[0.2081, 0.9298, 0.7231],
[0.7423, 0.5263, 0.2437],
[0.5846, 0.0332, 0.1387],
[0.2422, 0.8155, 0.7932]], device='cuda:1')
X out:
tensor([[0.9157, 1.3211, 1.0412],
[0.1682, 0.4927, 1.0075],
[0.7952, 1.8284, 0.6315],
[0.9021, 0.4996, 0.4334],
[0.2305, 1.0987, 1.0170],
[1.2609, 1.2240, 1.0437],
[0.7456, 0.3154, 0.8203],
[1.1574, 1.2126, 1.6673]], device='cuda:0')
These are the only possible outcomes. You can see that there are only two unique deltas. Each one is generated by either the first or the second thread, depending on which one calls torch.rand() first. The signature of the PRNG is not perfect, in the sense that two DIFFERENT PRNG states may have the same signature. It seems that the PRNG state is designed in some specific way with some properties. I didn’t spend much time trying to find a perfect unique signature.
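One possible way to get a signature that is far less likely to collide is to hash the serialized generator state. This is a minimal sketch, not what the code above does: it uses Python's `random` module so that it runs anywhere, and `state_signature` and `python_rng_signature` are hypothetical helpers. For the torch generator, the bytes could presumably be obtained with `torch.random.get_rng_state().numpy().tobytes()`, which is an assumption not exercised here.

```python
import hashlib
import random


def state_signature(state_bytes):
    """Return a short hex digest of a serialized generator state."""
    return hashlib.sha256(state_bytes).hexdigest()[:16]


def python_rng_signature(rng):
    # repr() of getstate() is a deterministic serialization of the state tuple.
    return state_signature(repr(rng.getstate()).encode("utf-8"))


if __name__ == "__main__":
    rng = random.Random(0)
    before = python_rng_signature(rng)
    rng.random()  # one draw advances the internal state
    after = python_rng_signature(rng)
    print("before:", before)
    print("after: ", after)
    print("changed:", before != after)
```

Two generators in the same state give the same digest, while any change to the state changes the digest with overwhelming probability.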
Imagine you call the forward function 100 times. Predicting the output is impossible, since it depends on the thread order in every single call. Now you understand why it is impossible to obtain reproducible results when threads share a random generator.
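To get a feeling for how quickly the number of possible outcomes grows, the sketch below enumerates every thread order over a few forward calls. It rests on assumptions: two threads, one thread-safe random call per thread per forward, the threads being joined at the end of each forward, and Python's `random` standing in for the shared generator. `simulate` and `count_outcomes` are hypothetical names.

```python
import itertools
import random


def simulate(orders, n=2, seed=0):
    """Run len(orders) forward passes; orders[k] says which of the two
    threads reaches the shared generator first in pass k."""
    shared = random.Random(seed)
    outputs = []
    for first in orders:
        second = 1 - first
        deltas = [None, None]
        deltas[first] = tuple(shared.random() for _ in range(n))
        deltas[second] = tuple(shared.random() for _ in range(n))
        outputs.append(tuple(deltas))
    return tuple(outputs)


def count_outcomes(num_calls):
    # Enumerate every possible thread order and count the distinct results.
    all_orders = itertools.product([0, 1], repeat=num_calls)
    return len({simulate(orders) for orders in all_orders})


if __name__ == "__main__":
    for num_calls in (1, 2, 3, 4):
        print(num_calls, count_outcomes(num_calls))
```

Under these assumptions the count doubles with every call, so 100 calls would allow 2^100 different output sequences.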