Weird behavior of torch seed

Hi All,

I am implementing a transformer in PyTorch and training it on synthetic data for a research project.

I want my results to be reproducible, so I set the seeds. However, I noticed behavior that seems very weird to me.
In the following snippet I get the expected behavior: model0 and model1 have exactly the same accuracy, and the experiment is repeatable on my machine.

from Synthetic_Transformer import Transformer
from Synthetic_Transformer import Train
import wandb

import torch
torch.manual_seed(0)

import random
random.seed(0)

import numpy as np
np.random.seed(0)

# ++ train model0 ++
wandb0 = wandb.init(...)
model0, test_loader, N, h = Train.train(wandb0)
wandb.finish()


# ++ train model1 ++
wandb1 = wandb.init(...)
model1, test_loader, N, h = Train.train(wandb1)
wandb.finish()

However, if I add a call to Transformer.make_model(), as in the next snippet, the two models no longer have the same accuracy, even though print(torch.initial_seed()) prints 0 everywhere, as expected. Transformer.make_model() builds a new transformer object (initialized with Xavier initialization) and, as far as I can tell, should not affect the random seed at all.

from Synthetic_Transformer import Transformer
from Synthetic_Transformer import Train
import wandb

import torch
torch.manual_seed(0)

import random
random.seed(0)

import numpy as np
np.random.seed(0)

# ++ train model0 ++
wandb0 = wandb.init(...)
model0, test_loader, N, h = Train.train(wandb0)
wandb.finish()
print(torch.initial_seed())

# this line should not affect the seed
model_dummy = Transformer.make_model(101, 101, N = 1, h = 1, residuals=False, bias = False)
print(torch.initial_seed())

# ++ train model1 ++
wandb1 = wandb.init(...)
model1, test_loader, N, h = Train.train(wandb = wandb1)
wandb.finish()

I found that to get the expected behavior, I need to reset the torch seed, as follows:

from Synthetic_Transformer import Transformer
from Synthetic_Transformer import Train
import wandb

import torch
torch.manual_seed(0)

import random
random.seed(0)

import numpy as np
np.random.seed(0)

# ++ train model0 ++
wandb0 = wandb.init(...)
model0, test_loader, N, h = Train.train(wandb0)
wandb.finish()
print(torch.initial_seed())

# this line should not affect the seed
model_dummy = Transformer.make_model(101, 101, N = 1, h = 1, residuals=False, bias = False)
print(torch.initial_seed())

# re-seed torch here to get the expected behavior (re-importing torch is not needed)
torch.manual_seed(0)

# ++ train model1 ++
wandb1 = wandb.init(...)
model1, test_loader, N, h = Train.train(wandb = wandb1)
wandb.finish()
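The workaround above amounts to reseeding every RNG right before each run. A minimal sketch of that pattern, assuming a hypothetical helper `seed_everything` and a hypothetical `train_once` that stands in for `Train.train` (whose code is not shown in the post):

```python
import random

import numpy as np
import torch
import torch.nn as nn


def seed_everything(seed: int) -> None:
    """Hypothetical helper: reseed the Python, NumPy and PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def train_once() -> torch.Tensor:
    """Hypothetical stand-in for Train.train: builds a model and uses random data."""
    model = nn.Linear(4, 2)      # random weight/bias initialization
    noise = torch.randn(2, 4)    # stands in for random data, dropout, etc.
    return model.weight.detach() + noise


seed_everything(0)
w0 = train_once()

# An extra RNG consumer between runs, like Transformer.make_model() in the post
_dummy = nn.Linear(101, 101)

seed_everything(0)  # reseed right before the second run
w1 = train_once()

print(torch.equal(w0, w1))  # True
```

Reseeding immediately before each run makes every run start from the same generator state, regardless of what happened in between.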

Transformer.make_model() therefore seems to change the torch seed, even though print(torch.initial_seed()) does not show any change, and I cannot find any explanation for why the seed would change.
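`torch.initial_seed()` reports the seed the default generator was initialized with, not the generator's current state. A minimal sketch to check whether creating a model advances that state, using `nn.Linear` plus `nn.init.xavier_uniform_` as an assumed stand-in for `Transformer.make_model()` (whose code is not shown):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
state_before = torch.get_rng_state()  # full CPU generator state (a uint8 tensor)

# Creating a layer draws random numbers for its default weight/bias initialization
layer = nn.Linear(101, 101)
nn.init.xavier_uniform_(layer.weight)  # Xavier init draws from the global RNG too

state_after = torch.get_rng_state()

print(torch.initial_seed())                    # 0: the seed itself is unchanged
print(torch.equal(state_before, state_after))  # False: the generator state advanced
```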

Running on CPU, Intel Mac, Python 3.9.13, PyTorch 1.14.0.

Thanks for your help.

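I don’t think that’s the case; the seed is still the one you’ve set. torch.initial_seed() only returns the seed the default generator was initialized with, not the generator's current state, which is why it keeps printing 0.
However, your description of Transformer.make_model() as defining a new transformer object (initialized with xavier) explains the divergence: initializing the new parameters calls into the pseudo-random number generator (PRNG).
Once seeded, the PRNG creates the same sequence of random numbers for the same order of operations.
If you break this assumption by adding more random calls to your script, the order of operations changes and you will no longer get the same values.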

Here is a small example:

import torch
import torch.nn as nn

torch.manual_seed(2809)
print(torch.randn(3))
# tensor([-2.0748,  0.8152, -1.1281])
print(torch.randn(3))
# tensor([ 0.8386, -0.4471, -0.5538])
print(torch.randn(3))
# tensor([-0.8776, -0.5635,  0.5434])

torch.manual_seed(2809)
print(torch.randn(3))
# tensor([-2.0748,  0.8152, -1.1281])
print(torch.randn(3))
# tensor([ 0.8386, -0.4471, -0.5538])
print(torch.randn(3))
# tensor([-0.8776, -0.5635,  0.5434])

torch.manual_seed(2809)
print(torch.randn(3))
# tensor([-2.0748,  0.8152, -1.1281])

# add new call into PRNG by initializing a new layer
layer = nn.Linear(1, 1)

print(torch.randn(3))
# tensor([0.8386, 0.2195, 1.3188]) # !!!
print(torch.randn(3))
# tensor([-1.6841,  0.9325, -1.8085]) # !!!
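One way to create an extra model without shifting the main random sequence is `torch.random.fork_rng`, which saves the RNG state when the `with` block is entered and restores it on exit. A minimal CPU-only sketch (`devices=[]` is an assumption that only the CPU generator needs forking):

```python
import torch
import torch.nn as nn

torch.manual_seed(2809)
ref = [torch.randn(3) for _ in range(3)]  # reference sequence, no extra RNG calls

torch.manual_seed(2809)
first = torch.randn(3)

# devices=[] forks only the CPU generator; its state is restored when the block exits
with torch.random.fork_rng(devices=[]):
    layer = nn.Linear(1, 1)  # these RNG calls no longer shift the main sequence

rest = [torch.randn(3), torch.randn(3)]

print(all(torch.equal(a, b) for a, b in zip(ref, [first, *rest])))  # True
```

Unlike reseeding, this leaves the main stream exactly where it was before the extra model was built, so later random calls are unaffected.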