Nondeterminism even when setting all seeds, 0 workers, and cudnn.deterministic

There are several other posts on the forums about nondeterminism. Most of the answers suggest some combination of setting all the random seeds, setting num_workers to 0, and setting torch.backends.cudnn.deterministic = True. I have tried all of the above, and I still get non-deterministic final accuracy numbers. Here is a short script that demonstrates this behavior: a simple convolutional neural network on CIFAR-10. I ran it 4 times in a row (same machine) and got 3 different final accuracies: 64.31, 64.72, 64.67. (Perhaps these aren’t huge differences, but with the real architecture and dataset I am using, I’ve observed differences of over 1% in final accuracy when running the same code, which is quite a bit in my opinion.) I am using Python 3.6.3, PyTorch 0.4.1, and a single NVIDIA Tesla K80 GPU, on CentOS Linux 7.4.1708.

```python
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Conv2d(3, 64, 5, padding=1)
    self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
    self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
    # 32x32 input: conv1 -> 30x30, pool -> 15x15, pool -> 7x7
    self.fc1 = nn.Linear(128 * 7 * 7, 1024)
    self.fc2 = nn.Linear(1024, 10)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.pool(self.conv2(x)))
    x = F.relu(self.conv3(x))
    x = F.relu(self.pool(self.conv4(x)))
    x = x.view(-1, 128 * 7 * 7)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, 0.2, training=self.training)
    x = self.fc2(x)
    return x


##### Get the data.
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, pin_memory=True,
                                           shuffle=True, num_workers=0,
                                           # only called when num_workers > 0
                                           worker_init_fn=lambda x: np.random.seed(1))
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False, num_workers=0)

##### Set the seeds.
torch.manual_seed(2)
torch.cuda.manual_seed_all(2)
torch.backends.cudnn.deterministic = True
np.random.seed(1)
random.seed(1)
device = torch.device('cuda')
# Autotuning: cuDNN times several algorithms and keeps the fastest one.
torch.backends.cudnn.benchmark = True

##### Train the model.
model = CNN()
model.to(device)
model = nn.DataParallel(model) # Not needed for just one GPU, but removing this doesn't fix the nondeterminism

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
for epoch in range(2):
  print('Epoch %d' % epoch)
  for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

##### Evaluate the model.
correct, tot = 0, 0
model.eval()
with torch.no_grad():
  for inputs, labels in test_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    tot += labels.size(0)
    correct += (predicted == labels).sum().item()

print(correct / tot * 100)
```
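The script passes `worker_init_fn=lambda x: np.random.seed(1)`, which is never called with `num_workers=0` and, with several workers, would give every worker the same NumPy seed. If you later need `num_workers > 0`, the pattern below (adapted from the reproducibility notes in recent PyTorch documentation) derives each worker's NumPy and `random` seeds from that worker's torch seed, and fixes the shuffle order with a dedicated, seeded `torch.Generator`. It is a sketch that assumes a PyTorch release much newer than 0.4.1, since the DataLoader `generator` argument does not exist there; `seed_worker` and `make_loader` are illustrative names, not library functions.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() is already distinct per worker
    # (base_seed + worker_id); reuse it for NumPy and Python's random module.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, seed, batch_size=64, num_workers=2):
    # A dedicated generator makes the shuffle order (and the workers' base
    # seed) depend only on `seed`, not on the global torch RNG state.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )


if __name__ == "__main__":
    data = TensorDataset(torch.arange(10).float())
    # num_workers=0 only to keep this demo single-process.
    order_a = [b[0].tolist() for b in make_loader(data, seed=0, batch_size=5, num_workers=0)]
    order_b = [b[0].tolist() for b in make_loader(data, seed=0, batch_size=5, num_workers=0)]
    print(order_a == order_b)  # same seed -> same shuffle order
```

Because the generator is local to each loader, other code that consumes the global torch RNG (weight initialization, dropout) no longer shifts the data order.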

Your issue might be related to this discussion.
Basically, even if you set torch.backends.cudnn.deterministic = True, setting torch.backends.cudnn.benchmark = True lets cuDNN time several algorithms at run time and pick the fastest one. Since these timings can vary, different runs may select different (individually deterministic) algorithms, which yields different results.
There is currently a PR working on this.

Could you disable benchmarking and try it again?
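A minimal sketch of a seeding helper that combines the usual advice with disabling benchmarking. The name `seed_everything` is hypothetical (it is not a PyTorch function). It removes the run-to-run variation in cuDNN algorithm selection, but it does not make every CUDA kernel deterministic.

```python
import random

import numpy as np
import torch


def seed_everything(seed):
    # Seed every RNG the training script touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe to call even without a GPU
    # Ask cuDNN for deterministic algorithms, and turn off autotuning so the
    # algorithm is chosen the same way on every run instead of by timing.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    seed_everything(2)
    a = torch.rand(3)
    seed_everything(2)
    b = torch.rand(3)
    print(torch.equal(a, b))  # True: same seed, same draws
```

Call it once, before the model and DataLoaders are created, so that weight initialization and the first shuffle both see the seeded state.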

This worked (gave deterministic results). Thank you!

Hi, I am also facing a similar issue. I set

```python
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```

right after the imports, but I am still getting non-deterministic behaviour. I have also tried torch.backends.cudnn.enabled = False, as I read that the cunn modules also behave deterministically.

I have shuffling turned off, the model in eval mode, and I am backpropagating to get the gradients with respect to the input image (to perform adversarial attacks), so as far as I can tell I have followed all the usual steps. Am I missing something?
Link to the thread I created (which in turn links to other questions I posted)
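Since the goal here is the gradient with respect to the input image, one way to narrow things down is to compute that gradient twice on the same batch and compare the results bitwise. If two back-to-back calls already differ, the nondeterminism comes from the forward or backward kernels themselves rather than from seeding or data order. A minimal sketch, assuming a classification `model` with cross-entropy loss; `input_gradient` and `check_input_gradient_reproducible` are hypothetical helpers, and the small model in the demo is only a stand-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def input_gradient(model, images, labels):
    # Gradient of the loss w.r.t. the input pixels (as used by FGSM-style attacks).
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    (grad,) = torch.autograd.grad(loss, images)
    return grad


def check_input_gradient_reproducible(model, images, labels):
    model.eval()  # disables dropout, uses running BatchNorm statistics
    g1 = input_gradient(model, images, labels)
    g2 = input_gradient(model, images, labels)
    # Bitwise comparison; (g1 - g2).abs().max() shows how large any drift is.
    return torch.equal(g1, g2)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Stand-in model and data; replace with your own network and image batch.
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                          nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
    images = torch.rand(4, 3, 32, 32)
    labels = torch.randint(0, 10, (4,))
    print(check_input_gradient_reproducible(model, images, labels))
```

Running this with the real model on the GPU (after moving `model`, `images`, and `labels` there) isolates the kernels from everything else in the pipeline.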

Also please look at https://github.com/pytorch/pytorch/issues/12207
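In newer PyTorch releases (not in the 0.4.1 used at the start of this thread), there is a more direct way to locate the source than toggling cuDNN flags: `torch.use_deterministic_algorithms(True)` makes PyTorch use deterministic implementations where they exist and raise a `RuntimeError` for operations that have none, so the traceback names the offending op. On CUDA 10.2 and later, cuBLAS also needs the `CUBLAS_WORKSPACE_CONFIG` environment variable (for example `:4096:8`). A sketch assuming such a recent version; `enable_strict_determinism` is a hypothetical helper name.

```python
import os

import torch


def enable_strict_determinism():
    # Set before any CUDA work starts; on CPU-only machines it is simply unused.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    # Prefer deterministic kernels and raise RuntimeError for ops that have none,
    # so the traceback points at the source of the nondeterminism.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    enable_strict_determinism()
    print(torch.are_deterministic_algorithms_enabled())  # True
```

This is a debugging aid as much as a fix: some ops (certain scatter/index-add backward passes on CUDA, for instance) will then fail loudly instead of silently producing slightly different results.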