Manual seed cannot make dropout deterministic on CUDA in the PyTorch 1.0 preview version

The testing code is very simple:

import torch
import torch.nn as nn
seed = 1
model= nn.Dropout(0.5)
use_cuda = True

for i in range(3):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    model.train()
    data = torch.randn(4, 4)
    if i > 0:
        print(torch.equal(data, pre_data))
    pre_data = data
    if use_cuda:
        data = data.cuda()
    out = model(data)
    loss = out.sum()
    print(i, loss.item())

For torch-nightly-1.0.0.dev20181001, the output is:

0 -5.580026626586914
True
1 -6.234308242797852
True
2 -8.840182304382324

This means the random state is not entirely determined by the manual seed for dropout on CUDA (the CPU operation is fine in my tests), while the outputs under PyTorch v0.4.1 are identical across iterations. I guess it is a bug in the preview version.


Hi,

If you’re using the GPU, you need to set the CUDA random seed as well with torch.cuda.manual_seed().
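For instance, a minimal sketch of that, seeding both generators before each run (torch.cuda.manual_seed seeds the current device; for multi-GPU setups you would use torch.cuda.manual_seed_all):

import torch

seed = 1
torch.manual_seed(seed)       # seeds the CPU RNG
torch.cuda.manual_seed(seed)  # seeds the RNG of the current CUDA device
# torch.cuda.manual_seed_all(seed)  # if more than one GPU is involved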

Oh, I thought torch.cuda.manual_seed() was merged into torch.manual_seed(), since this works fine in PyTorch v0.4.1.

I don’t think this was ever a single function. I’ve never seen such a change, and I really doubt such a change would be reverted, as that would be very much not backward compatible.

Sorry to tell you that it cannot be fixed by adding torch.cuda.manual_seed(seed) in the loop; I have updated the snippet above accordingly.

I thought the same, but then this information from the reproducibility docs might be misleading:

You can use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)

Interesting: after looking at the code, torch.manual_seed does indeed call torch.cuda.manual_seed_all() 🙂
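A quick way to check this from Python (a sketch, assuming a CUDA device is available; on a build where the documented behaviour holds, it should print True):

import torch

torch.manual_seed(1)                 # no explicit torch.cuda.manual_seed call
a = torch.randn(3, device='cuda')
torch.manual_seed(1)
b = torch.randn(3, device='cuda')
print(torch.equal(a, b))             # True if the CUDA RNG was re-seeded as well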

It has been there in one form or another for quite a while: https://github.com/pytorch/pytorch/pull/1762
Last I heard, it was improved to lazily seed all GPUs instead of doing it at call time.
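Roughly, “lazy” here means the seed is only recorded at call time and applied when a device’s generator is first used. A toy illustration of that idea (the helper names are made up; this is not PyTorch’s actual implementation):

import torch

_pending_seeds = {}
_generators = {}

def lazy_seed_all(seed, devices):
    # only record the seed; no generator is touched yet
    for d in devices:
        _pending_seeds[d] = seed

def generator_for(device):
    # create and seed the generator when it is first needed
    if device not in _generators:
        g = torch.Generator()
        g.manual_seed(_pending_seeds.pop(device, 0))
        _generators[device] = g
    return _generators[device]

lazy_seed_all(1, ['dev0', 'dev1'])
print(generator_for('dev0').initial_seed())  # the seed is applied only now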

Apparently something doesn’t work with the (CUDA) manual_seed, though. To decide whether dropout or the seeding is the source of the error, I checked two things:

  • manual_seed also doesn’t render bernoulli deterministic, so the problem is not in dropout.
  • When using set_rng_state, you do get reproducible random numbers (see the snippet below).
import torch
import torch.nn as nn

seed = 1
l = torch.cuda.get_rng_state()

print("manual_seed - different each time")
for i in range(3):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    lastl = l
    l = torch.cuda.get_rng_state()
    print((l == lastl).all())     # does re-seeding reproduce the previous state?
    a = torch.bernoulli(torch.full((3, 3), 0.5, device='cuda'))
    print(a)                      # differs from one iteration to the next

print("set_rng_state - same each time")
for i in range(3):
    torch.cuda.set_rng_state(l)   # restore a saved state instead of seeding
    a = torch.bernoulli(torch.full((3, 3), 0.5, device='cuda'))
    print(a)                      # identical in every iteration
    b = nn.functional.dropout(torch.ones(3, 3, device='cuda'))
    print(b)                      # dropout is reproducible, too

If you mention me (t-vi on GitHub) on a bug report, I’ll try to figure out what’s going wrong and produce a fix.

Best regards

Thomas


Interestingly, changing the dropout layer to be inplace seems to make dropout deterministic.
I guess that would be an unrelated fix then?
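For reference, a sketch of how one could compare the two code paths (assuming a CUDA device; each branch is re-seeded before every run):

import torch
import torch.nn as nn

for inplace in (False, True):
    outs = []
    for _ in range(2):
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        drop = nn.Dropout(0.5, inplace=inplace)
        drop.train()
        x = torch.ones(4, 4, device='cuda')
        outs.append(drop(x).clone())   # clone so the inplace branch keeps its result
    print('inplace' if inplace else 'non-inplace', torch.equal(outs[0], outs[1]))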

Hm. I can’t see that inplace helps.
I think I have to correct my initial assessment and now think that bernoulli itself is broken (dropout uses bernoulli). I think it has recently been moved to native; maybe something went wrong there. I’ll look a bit more. Edit: it turns out it was the seeding, which is now fixed in master.

There is a special fused kernel for the non-inplace version that reproduces the logic from bernoulli. So if the bernoulli logic is broken, this one might be as well. The fused kernel is here if you want to take a look at it.

So I did some digging: to me it looks like manual_seed, which ends up calling createGeneratorState, only sets the MTGP32 states (those of the RNG used in the THC code), but not the seed + offset for Philox (the RNG used in the native code).

I filed a bug.

@wandering007 Thank you for sharing this with us and providing the testing code! A reproducing example is gold. This should be fixed soon. As a workaround I would recommend using torch.cuda.set_rng_state.
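A minimal sketch of that workaround, reusing the setup from the original snippet (save the CUDA RNG state once, restore it before every run):

import torch
import torch.nn as nn

model = nn.Dropout(0.5)
model.train()
state = torch.cuda.get_rng_state()   # capture the CUDA RNG state once

for i in range(3):
    torch.cuda.set_rng_state(state)  # restore it before each run
    torch.manual_seed(1)             # CPU RNG for torch.randn
    data = torch.randn(4, 4).cuda()
    print(i, model(data).sum().item())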
