Determinism in inference

When I run the script below several times in a row, the predictions differ between runs. Is there a setting I’m missing?

import torch
import numpy as np
import random

# Set the seed for all random number generators
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set CUBLAS workspace config to ensure deterministic cuBLAS
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

torch.use_deterministic_algorithms(True)
torch.set_num_threads(1)

# Load the model
model = torch.jit.load('test.pt')
model = model.to(torch.device('cuda:0'))
model = model.half()

# Create a random array
torch.manual_seed(0)
random_arr = torch.randn(1,1,224, 224, 224, dtype = torch.half).to(torch.device('cuda:0'))
pred = model(random_arr).to(torch.device('cpu')).detach().numpy()

The pred variable differs between runs; the maximum difference is about 3e-2. Any help would be appreciated. Thanks!

Could you post a minimal and executable code snippet reproducing the issue?

Thanks for your response. Here are two scripts; please save them as script_1.py and script_2.py in the working directory. For this experiment I’m using a model from Hugging Face, but the results are similar with my own segmentation model (nnU-Net).

# Save this as script_1.py

import torch
import numpy as np
import random
import requests
import pickle
import argparse

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--result_filename", help="location to save the result prediction")

    args = parser.parse_args()
    
    # Set the seed for all random number generators
    seed = 42
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ensure deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Set CUBLAS workspace config to ensure deterministic cuBLAS
    import os
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    torch.use_deterministic_algorithms(True)
    torch.set_num_threads(1)

    # Download the file
    url = 'https://huggingface.co/spaces/Xajimel/Practica3/resolve/main/unet.pth'
    response = requests.get(url)
    with open('model.pt', 'wb') as f:
        f.write(response.content)

    model = torch.jit.load('model.pt')
    model = model.to(torch.device('cuda:0'))
    model = model.half()

    # Create a random array
    torch.manual_seed(0)
    random_arr = torch.randn(1, 3, 224, 224, dtype = torch.half).to(torch.device('cuda:0'))
    pred = model(random_arr).to(torch.device('cpu')).detach().numpy()

    with open(args.result_filename, 'wb') as f:
        pickle.dump(pred, f)

# Save this as script_2.py

import pickle
import numpy as np

with open('pred1.pkl', 'rb') as f:
    a = pickle.load(f)

with open('pred2.pkl', 'rb') as f:
    b = pickle.load(f)

print('Length of a: ', len(a))
print('Length of b: ', len(b))

print('Shape of first element in a: ', a[0].shape)
print('Shape of first element in b: ', b[0].shape)

diff = a[0] - b[0]

print('Min diff: ', np.min(diff))
print('Max diff: ', np.max(diff))

Run the scripts with the following commands:

python script_1.py --result_filename pred1.pkl
python script_1.py --result_filename pred2.pkl
python script_2.py
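script_2.py prints the signed minimum and maximum of a[0] - b[0]. A slightly more informative check is sketched below. The helper name compare_predictions is hypothetical and not part of the original scripts. It reports whether the arrays are equal element by element, how many elements differ, and the largest absolute difference, and it casts to float32 first so the subtraction itself does not round in half precision.

```python
import numpy as np


def compare_predictions(a, b):
    """Summarise how two prediction arrays differ (hypothetical helper)."""
    a = np.asarray(a)
    b = np.asarray(b)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")

    # Subtract in float32 so the difference is not rounded to float16.
    abs_diff = np.abs(a.astype(np.float32) - b.astype(np.float32))

    return {
        # Value equality of every element (NaNs compare as unequal).
        "all_equal": bool(np.array_equal(a, b)),
        "num_different": int(np.count_nonzero(a != b)),
        "max_abs_diff": float(abs_diff.max()),
    }


if __name__ == "__main__":
    x = np.array([1.0, 2.0, 3.0], dtype=np.float16)
    y = x.copy()
    y[1] = 2.5
    print(compare_predictions(x, y))
```

In script_2.py this could be called as compare_predictions(a, b) after loading the two pickles.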

Interestingly, if I change the URL in script_1.py to a different model (say this one), I see no differences in the predictions.

So I wonder whether the issue is non-determinism in some layers of the nnU-Net model, or the way I’m saving the model. This is how I save it:

self.network = self.network.to(self.device)
self.network.eval()
torch.use_deterministic_algorithms(True)
traced_model = torch.jit.trace(self.network, torch.randn(1, 1, 224, 224, 224).to(self.device))
torch.jit.save(traced_model, 'model.pt')

Let me know what you think or if you need some more info. Thanks!
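One way to narrow down the question above is to run the same loaded model on the same input several times inside a single process. If those outputs already differ, the variation comes from the forward pass itself and not from process startup or from how the file was saved. The sketch below assumes the model returns a single tensor; the function name max_diff_across_runs is hypothetical, and the small nn.Linear in the usage part is only a stand-in for the traced model.

```python
import torch


def max_diff_across_runs(model, x, n_runs=3):
    """Largest absolute difference between the first forward pass and later ones.

    Assumes model(x) returns a single tensor (hypothetical helper).
    """
    model.eval()
    worst = 0.0
    with torch.no_grad():
        # Compare in float32 on the CPU so the comparison itself adds no half-precision rounding.
        reference = model(x).float().cpu()
        for _ in range(n_runs - 1):
            out = model(x).float().cpu()
            worst = max(worst, (out - reference).abs().max().item())
    return worst


if __name__ == "__main__":
    torch.manual_seed(0)
    stand_in = torch.nn.Linear(8, 4)  # placeholder for torch.jit.load('model.pt')
    sample = torch.randn(2, 8)
    print(max_diff_across_runs(stand_in, sample))
```

With the real model, x would be the same half-precision CUDA tensor that script_1.py builds.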

I cannot reproduce the issue and see:

Length of a:  1
Length of b:  1
Shape of first element in a:  (5, 224, 224)
Shape of first element in b:  (5, 224, 224)
Min diff:  0.0
Max diff:  0.0

This is what I see. I’m on torch 2.3.0+cu121 if that helps.

Length of a:  1
Length of b:  1
Shape of first element in a:  (5, 224, 224)
Shape of first element in b:  (5, 224, 224)
Min diff:  -0.01563
Max diff:  0.01172
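For scale: 0.01563 is 2^-6 = 0.015625 after rounding, which is the gap between adjacent float16 values in the range [16, 32). Whether the model's outputs actually fall in that range is an assumption, since the thread does not say. If they do, differences of this size would be about one half-precision rounding step. The sketch below is a general illustration and not a diagnosis of this model. It shows that gap and shows that float16 addition gives different results depending on the order in which the same numbers are summed.

```python
import numpy as np

# Gap between adjacent float16 values at 16.0 (the same gap holds across [16, 32)).
step = np.spacing(np.float16(16.0))
print(float(step))  # 0.015625

# Floating-point addition is not associative: at 2048 the float16 gap is 2,
# so adding 1 twice is lost to rounding, while adding 2 once is exact.
a = np.float16(2048.0)
one = np.float16(1.0)

left_to_right = (a + one) + one   # each +1 rounds back to 2048
right_to_left = a + (one + one)   # 2048 + 2 is exactly representable

print(float(left_to_right), float(right_to_left))  # 2048.0 2050.0
```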

@ptrblck On the machine that I was initially testing on, I see the issue on both torch 2.3.0+cu121 and torch 2.4.0+cu121. On another machine, I don’t see the issue in either version. The two machines have different GPUs (the first: A40; the second: RTX 3050 Ti).

Could the underlying GPU make a difference?
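When comparing two machines like this, it can help to record the full software and hardware stack next to each result, since the PyTorch reproducibility notes say that identical results are not guaranteed across releases or platforms. The sketch below gathers the relevant versions and the GPU name; the function name describe_environment is hypothetical, and the GPU fields stay None when no CUDA device is visible.

```python
import torch


def describe_environment():
    """Collect version and device details worth attaching to each result (hypothetical helper)."""
    info = {
        "torch": torch.__version__,
        "cuda_runtime": torch.version.cuda,        # None on CPU-only builds
        "cudnn": torch.backends.cudnn.version(),   # None if cuDNN is unavailable
        "gpu": None,
        "compute_capability": None,
    }
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
        info["compute_capability"] = torch.cuda.get_device_capability(0)
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Saving this dictionary alongside pred1.pkl and pred2.pkl would make results from the A40 and the RTX 3050 Ti easier to compare.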

Reproducibility problems with nnU-Net have been observed before.

At this stage, I am not aware of a solution.