Reproducibility issue between Intel and AMD CPUs

I am using torch==1.8.1+cpu and transformers==4.7.0, installed with pip in a Python 3.6 Docker container.
Here is the build info:

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.1, USE_CUDA=0, USE_CUDNN=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, 

With this setup, I get different results for the same model forward pass on an Intel CPU and an AMD CPU. I have read about non-deterministic ops, but none of them seem to match my situation.

base_model = 'google/bert_uncased_L-12_H-256_A-4'

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import torch
import numpy as np
import random
from transformers import BertConfig, BertTokenizerFast, BertForSequenceClassification

torch.set_printoptions(precision=11)
print(torch.cuda.is_available())
def set_random_seed(seed_value: int, use_cuda: bool = False):
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    random.seed(seed_value)
    torch.use_deterministic_algorithms(True)
    if use_cuda:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False  # True lets cuDNN pick (possibly non-deterministic) algorithms per run


set_random_seed(2019, use_cuda=False)
tokenizer = BertTokenizerFast.from_pretrained(base_model, do_lower_case=True, max_len=256)
bert_config = BertConfig.from_pretrained(base_model, num_labels=2)
model = BertForSequenceClassification.from_pretrained(base_model, config=bert_config)
batch_text_or_text_pairs = [['hello', 'world']]
all_label_ids = [1]


inputs = tokenizer.batch_encode_plus(
    batch_text_or_text_pairs=batch_text_or_text_pairs,
    add_special_tokens=True,
    padding='max_length',
    truncation='longest_first',
    max_length=256,
    stride=0,
    is_split_into_words=False,
    pad_to_multiple_of=None,
    return_tensors='pt',
    return_token_type_ids=True,
    return_attention_mask=True,
    return_overflowing_tokens=False,
    return_special_tokens_mask=False,
    return_offsets_mapping=False,
    return_length=False,
    verbose=False,
)

inputs['labels'] = torch.tensor(all_label_ids, dtype=torch.long)
model.zero_grad()
model = model.train()
model(**inputs)

On two Intel CPUs

Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
Intel(R) Core(TM) i7-8565U CPU @ 1.80GHz

this consistently gives

SequenceClassifierOutput(loss=tensor(0.53043878078, grad_fn=<NllLossBackward>), logits=tensor([[-0.31894901395,  0.03818603605]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

On AMD CPUs

AMD EPYC 7V12 64-Core Processor

this consistently gives

SequenceClassifierOutput(loss=tensor(0.56962001324, grad_fn=<NllLossBackward>), logits=tensor([[-0.26781406999, -0.00332138361]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)

Should I compile torch from source using Eigen or OpenBLAS for reproducibility? Or is something else going on that I can’t control?
My libtorch*.so files do not link against libmkl:

./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so:
	linux-vdso.so.1 (0x00007ffc05543000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f9a89df6000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f9a89c73000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007f9a89a49000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f9a89a2f000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f9a89a0e000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f9a8984d000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f9a8a21d000)
./usr/local/lib/python3.6/site-packages/torch/lib/libcaffe2_detectron_ops.so:
	linux-vdso.so.1 (0x00007fff351c0000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007f5448580000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007f5448356000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f544832b000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007f5436c5e000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007f54369c7000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f5436844000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f543683d000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f54366b9000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f543669f000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f54364de000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f544a508000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007f54364d4000)
./usr/local/lib/python3.6/site-packages/torch/lib/libcaffe2_module_test_dynamic.so:
	linux-vdso.so.1 (0x00007fff5618a000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007f18e8fa5000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007f18d78d8000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007f18d7641000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f18d74b3000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f18d7499000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f18d72d8000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f18e93c5000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007f18d70ac000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f18d708b000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007f18d7081000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f18d707c000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f18d6ef9000)
./usr/local/lib/python3.6/site-packages/torch/lib/libcaffe2_observers.so:
	linux-vdso.so.1 (0x00007ffc705cf000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007fe0e9fc0000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007fe0d88f3000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007fe0d865c000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fe0d8631000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fe0d84ad000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fe0d8493000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fe0d82d0000)
	/lib64/ld-linux-x86-64.so.2 (0x00007fe0ea437000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007fe0d80a6000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fe0d809c000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fe0d8097000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fe0d7f14000)
./usr/local/lib/python3.6/site-packages/torch/lib/libjitbackend_test.so:
	linux-vdso.so.1 (0x00007fff49fad000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007fde84056000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007fde72989000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007fde726f2000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fde72564000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fde7254a000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fde72389000)
	/lib64/ld-linux-x86-64.so.2 (0x00007fde8448a000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007fde7215d000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fde7213c000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fde72132000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fde7212d000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fde71faa000)
./usr/local/lib/python3.6/site-packages/torch/lib/libshm.so:
	linux-vdso.so.1 (0x00007ffc4e5fa000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007fa1936a9000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fa193695000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007fa1933fe000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007fa181d31000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fa181d10000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fa181b8c000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fa181b70000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fa1819af000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fa18182c000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007fa181602000)
	/lib64/ld-linux-x86-64.so.2 (0x00007fa193ac6000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fa1815fd000)
./usr/local/lib/python3.6/site-packages/torch/lib/libtorchbind_test.so:
	linux-vdso.so.1 (0x00007ffef2bef000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007f5ec4810000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007f5eb3143000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007f5eb2eac000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007f5eb2d1e000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007f5eb2d04000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f5eb2b43000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f5ec4c7a000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007f5eb2917000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f5eb28f6000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007f5eb28ec000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f5eb28e7000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f5eb2764000)
./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so:
	linux-vdso.so.1 (0x00007ffd9bdf3000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007fc91be38000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fc91be0d000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fc91be03000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fc91bde9000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fc91bde4000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fc91bc61000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007fc91b9c8000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fc91b844000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fc91b683000)
	/lib64/ld-linux-x86-64.so.2 (0x00007fc92d731000)
./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_global_deps.so:
	linux-vdso.so.1 (0x00007ffc720d9000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007f8f8f1f3000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007f8f8f070000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007f8f8f06b000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007f8f8ee41000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007f8f8ec80000)
	/lib64/ld-linux-x86-64.so.2 (0x00007f8f8f423000)
./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_python.so:
	linux-vdso.so.1 (0x00007ffc78bb5000)
	libshm.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libshm.so (0x00007fcd33fbb000)
	libtorch.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so (0x00007fcd33da7000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fcd33d93000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007fcd33afc000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007fcd2242f000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fcd2240e000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fcd22407000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fcd22283000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fcd22269000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fcd220a8000)
	/lib64/ld-linux-x86-64.so.2 (0x00007fcd35274000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fcd21f25000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007fcd21cfb000)
./usr/local/lib/python3.6/site-packages/torch/lib/libtorch.so:
	linux-vdso.so.1 (0x00007ffcb1d87000)
	libtorch_cpu.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so (0x00007fcbee57e000)
	libgcc_s.so.1 => /lib/x86_64-linux-gnu/libgcc_s.so.1 (0x00007fcbee55a000)
	libc.so.6 => /lib/x86_64-linux-gnu/libc.so.6 (0x00007fcbee399000)
	libgomp-a34b3233.so.1 => /./usr/local/lib/python3.6/site-packages/torch/lib/libgomp-a34b3233.so.1 (0x00007fcbee16f000)
	libpthread.so.0 => /lib/x86_64-linux-gnu/libpthread.so.0 (0x00007fcbee14e000)
	librt.so.1 => /lib/x86_64-linux-gnu/librt.so.1 (0x00007fcbee144000)
	libdl.so.2 => /lib/x86_64-linux-gnu/libdl.so.2 (0x00007fcbee13d000)
	libm.so.6 => /lib/x86_64-linux-gnu/libm.so.6 (0x00007fcbedfba000)
	libc10.so => /./usr/local/lib/python3.6/site-packages/torch/lib/libc10.so (0x00007fcbedd23000)
	libstdc++.so.6 => /usr/lib/x86_64-linux-gnu/libstdc++.so.6 (0x00007fcbedb9f000)
	/lib64/ld-linux-x86-64.so.2 (0x00007fcbffe61000)
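If it helps, the MKL status can also be queried from Python instead of inferring it from `ldd`. A minimal sketch using `torch.__config__.show()` and `torch.backends.mkl.is_available()`; the assumption here is that the pip wheels link MKL statically into `libtorch_cpu.so`, which, if true, would explain why `ldd` lists no separate `libmkl` even though the build info reports `USE_MKL=ON`.

```python
import torch

# Full build configuration (the same text as the "PyTorch built with:" block)
print(torch.__config__.show())

# Whether this build can use MKL and MKL-DNN (oneDNN) at runtime
print("MKL available:", torch.backends.mkl.is_available())
print("MKL-DNN available:", torch.backends.mkldnn.is_available())
```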

I don’t think there is any guarantee of bitwise-identical results across different hardware.
Enabling deterministic algorithms only makes sure you get the same values on the currently used system/setup.

Ah, that’s disappointing :frowning:

I tried three more things:

  1. Predicting with an already trained model: the confidence scores are not exactly equal, but the precision is still very high (absolute difference <= 1e-6).
  2. Just changing model.train() to model.eval(): same as 1, not exact but still very high precision.
  3. Checking pseudo-random number generation by calling torch.rand() after setting the seed: the results here were exact.
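A tolerance-based comparison like the one in points 1 and 2 could be scripted by saving the outputs on each machine and comparing them afterwards. A minimal sketch; `compare_outputs`, the file names, and the `atol=1e-6` threshold are illustrative assumptions, and the tensors at the bottom are made-up values loosely based on the logits above.

```python
import torch

def compare_outputs(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-6) -> dict:
    """Summarize how far apart two output tensors are."""
    diff = (a - b).abs()
    return {
        "bitwise_equal": torch.equal(a, b),
        "max_abs_diff": diff.max().item(),
        # rtol=0 so only the absolute tolerance matters
        "allclose": torch.allclose(a, b, rtol=0.0, atol=atol),
    }

# On each machine (hypothetical file names):
#   torch.save(logits.detach(), "logits_intel.pt")
#   torch.save(logits.detach(), "logits_amd.pt")
# Then on one machine:
#   report = compare_outputs(torch.load("logits_intel.pt"), torch.load("logits_amd.pt"))

# Self-contained illustration with made-up values
a = torch.tensor([[-0.3189, 0.0382]])
b = a + 5e-7
print(compare_outputs(a, b))
```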

I wonder which operations in this training setup/model architecture make the difference so large. Since putting the model in train mode triggers it, I am going to check the outputs layer by layer; maybe something is up with the dropout, batch norm, or layer norm layers.
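A layer-by-layer check like this can be done with forward hooks, which record every submodule's output during a single forward pass. A minimal sketch; `capture_layer_outputs` is a hypothetical helper, and the small `nn.Sequential` stands in for the BERT model so the example is self-contained.

```python
import torch
import torch.nn as nn

def capture_layer_outputs(model: nn.Module, *args, **kwargs) -> dict:
    """Run one forward pass and record the tensor output of every named submodule."""
    captured = {}
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules (e.g. transformer layers) return tuples; keep plain tensors only
            if isinstance(output, torch.Tensor):
                captured[name] = output.detach().clone()
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        model(*args, **kwargs)
    finally:
        for h in handles:
            h.remove()  # leave the model unmodified
    return captured

# Toy stand-in for the real model (an assumption, to keep the sketch runnable)
torch.manual_seed(2019)
toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Dropout(0.1), nn.Linear(8, 2))
toy.train()
outs = capture_layer_outputs(toy, torch.randn(4, 8))
for name, t in outs.items():
    print(name, tuple(t.shape), t.abs().sum().item())
```

For the BERT model the call would look like `capture_layer_outputs(model, **inputs)`; saving the returned dict with `torch.save` on each machine and diffing entry by entry shows the first module whose output diverges.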

That’s a good idea, as I wouldn’t entirely rule out a numerical issue in one of your setups (e.g. one caused directly by the framework, by a 3rd-party library, etc.).

So I am back after some debugging. Here are my findings:

Dropout kernels produce different masks across different CPUs (Intel vs AMD) and GPUs (V100 vs T4). They only produce the same mask on matching hardware: the two Intel CPUs agree with each other, but the AMD CPU and each GPU model differ from everything else.

LayerNorm CPU and CUDA kernels don’t produce exactly the same output, but at least the outputs are close enough. CPU outputs match exactly across different CPUs, and GPU outputs match exactly across different GPUs.
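To compare dropout masks across machines without shipping whole tensors around, one option is to hash the mask that `nn.Dropout` produces on a tensor of ones. A minimal sketch; `dropout_mask_digest`, the shape, and `p` are illustrative choices, not taken from the linked repository.

```python
import hashlib

import torch
import torch.nn as nn

def dropout_mask_digest(seed: int, shape, p: float = 0.1, device: str = "cpu") -> str:
    """Return a short hash of the mask nn.Dropout draws for a given seed."""
    torch.manual_seed(seed)
    drop = nn.Dropout(p).train()
    out = drop(torch.ones(shape, device=device))
    # Dropped elements are exactly 0; kept ones are scaled to 1 / (1 - p)
    mask = (out != 0).to(torch.uint8).cpu()
    return hashlib.sha256(mask.numpy().tobytes()).hexdigest()[:16]

# Run on each machine and compare the printed digests
print(dropout_mask_digest(2019, (256, 256)))
```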

My minimal reproduction code along with runs from three different machines is here: GitHub - chiragjn/torch-1.8.1-reproducibility-bug

In summary

torch.nn.LayerNorm

| LayerNorm (train mode) | Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz | AMD EPYC 7V12 64-Core Processor | Intel(R) Core™ i7-8565U CPU @ 1.80GHz | Tesla V100-PCIE-16GB, 460.27.04 | Tesla T4, 460.27.04 |
| --- | --- | --- | --- | --- | --- |
| Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz | Exact | - | - | - | - |
| AMD EPYC 7V12 64-Core Processor | Exact | Exact | - | - | - |
| Intel(R) Core™ i7-8565U CPU @ 1.80GHz | Exact | Exact | Exact | - | - |
| Tesla V100-PCIE-16GB, 460.27.04 | Close enough | Close enough | Close enough | Exact | - |
| Tesla T4, 460.27.04 | Close enough | Close enough | Close enough | Exact | Exact |

torch.nn.Dropout

| Dropout (train mode) | Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz | AMD EPYC 7V12 64-Core Processor | Intel(R) Core™ i7-8565U CPU @ 1.80GHz | Tesla V100-PCIE-16GB, 460.27.04 | Tesla T4, 460.27.04 |
| --- | --- | --- | --- | --- | --- |
| Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz | Exact | - | - | - | - |
| AMD EPYC 7V12 64-Core Processor | Different | Exact | - | - | - |
| Intel(R) Core™ i7-8565U CPU @ 1.80GHz | Exact | Different | Exact | - | - |
| Tesla V100-PCIE-16GB, 460.27.04 | Different | Different | Different | Exact | - |
| Tesla T4, 460.27.04 | Different | Different | Different | Different | Exact |

We can live with the LayerNorm differences, but different dropout masks might result in significantly different models after a fair number of batches, especially when dropout is used after almost every layer :upside_down_face:

Thanks for the update. The pseudorandom number generators use different implementations on different hardware, which is why random operations such as dropout don’t produce the same “random” numbers.
Yes, I think you are right that “different” models would be learned, but as long as the random sampling process is not broken, I would expect to see a similar model performance.
If the success of the model training itself depends on a specific random seed, it would be a bad sign for the stability of the overall training routine.
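One way to turn "similar model performance across seeds" into an automated check is to assert on a band of outcomes over several seeds rather than on exact values. A minimal sketch on a toy regression problem; the model, data, step count, and the `0.5 * baseline` threshold are all illustrative assumptions, not something from this thread.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_data():
    """Fixed synthetic regression data, identical for every seed."""
    g = torch.Generator().manual_seed(0)
    X = torch.randn(256, 10, generator=g)
    w_true = torch.randn(10, 1, generator=g)
    y = X @ w_true + 0.01 * torch.randn(256, 1, generator=g)
    return X, y

def train_once(seed: int, steps: int = 300) -> float:
    """Train a tiny MLP with dropout under one seed; return the final eval MSE."""
    torch.manual_seed(seed)  # controls weight init and dropout masks only
    X, y = make_data()
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Dropout(0.1), nn.Linear(32, 1))
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    model.train()
    for _ in range(steps):
        opt.zero_grad()
        F.mse_loss(model(X), y).backward()
        opt.step()
    model.eval()
    with torch.no_grad():
        return F.mse_loss(model(X), y).item()

X, y = make_data()
baseline = ((y - y.mean()) ** 2).mean().item()  # MSE of always predicting the mean
losses = [train_once(seed) for seed in range(5)]
print("baseline:", baseline, "per-seed losses:", losses)

# Assert on a band of outcomes instead of exact values
assert max(losses) < 0.5 * baseline
```

A check of this shape should tolerate different dropout masks on different hardware, at the cost of needing several runs and a threshold chosen with some margin.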

I agree, but it means I can’t write automated tests for training routines and have them pass on any virtual machine.

But apart from that, I am wondering how come torch.rand and a few other random ops can be made deterministic across all kinds of devices, but dropout can’t be? :thinking:

I will try to dig into the kernel code myself when I get the time, but would love an answer if someone is already familiar with the code.

I don’t think that’s the case. At least I see the expected differences:

torch.manual_seed(2809)
print(torch.randn(10))
# > tensor([-2.0748,  0.8152, -1.1281,  0.8386, -0.4471, -0.5538, -0.8776, -0.5635,
#           0.5434, -0.8192])

torch.manual_seed(2809)
print(torch.randn(10, device='cuda'))
# > tensor([ 0.5603,  1.8841, -0.1417, -0.6460, -2.2665, -1.3027,  0.3552, -0.5376,
#           0.1375,  1.6309], device='cuda:0')

Ah, I see. In my experiments, I have always been creating my inputs on the CPU and then moving them to my test device. Fair enough, I won’t expect the same output when comparing CPU vs GPU.

But still, torch.rand behaves (or at least so far seems to behave) exactly the same when I move from Intel to AMD and/or from V100 to T4, while torch.nn.Dropout definitely doesn’t :sweat_smile:

I might have to read the hardware manuals now to check the RNG implementations.

Just leaving this here: the CPU Bernoulli kernels are indeed customized for Intel.
On Intel CPUs, the vectorized RNG from MKL is used, and on AMD it falls back to the default implementation.
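Given that, one possible workaround (not suggested in this thread) is to draw the dropout mask with `torch.rand` on the CPU, which the earlier experiments found to match across Intel and AMD, and then move the mask to the target device. A minimal sketch; `CPUMaskDropout` is a hypothetical module, the vendor independence of CPU `torch.rand` is an assumption based on those observations, and its masks will not match `nn.Dropout`'s.

```python
import torch
import torch.nn as nn

class CPUMaskDropout(nn.Module):
    """Hypothetical dropout whose mask is drawn with torch.rand on the CPU.

    Assumption: CPU torch.rand does not take the Intel-only MKL Bernoulli path,
    so the mask should not depend on the CPU vendor. Masks differ from nn.Dropout.
    """

    def __init__(self, p: float = 0.1, seed: int = 0):
        super().__init__()
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p
        # Dedicated CPU generator so the mask stream is independent of other RNG use
        self.generator = torch.Generator().manual_seed(seed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        keep = torch.rand(x.shape, generator=self.generator) >= self.p
        # Build the mask on the CPU, then move it to x's device and dtype
        mask = keep.to(device=x.device, dtype=x.dtype)
        return x * mask / (1.0 - self.p)

drop = CPUMaskDropout(p=0.5, seed=123)
print(drop(torch.ones(4, 4)))
```

Because the generator state lives inside the module, resuming training from a checkpoint would also require saving and restoring `self.generator.get_state()`.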

For CUDA I am not sure what can be done, but for some reason the RNG does not behave the same across GPUs:

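import torch

def check(seed, device, size):
    torch.set_printoptions(precision=11)
    torch.manual_seed(seed)  # also seeds all CUDA devices
    torch.use_deterministic_algorithms(True)
    if 'cuda' in device:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False  # True lets cuDNN pick (possibly non-deterministic) algorithms
    device = torch.device(device)
    # Generate the input on the CPU, then move it, so every device starts from identical values
    x = torch.rand(*size)
    x = x.to(device)
    print(x)
    # bernoulli_ is in-place: it overwrites x with 0/1 samples drawn on `device`
    a = x.bernoulli_(p=0.5)
    print(a.sum(dim=0))
    print(a.sum(dim=1))

check(42, 'cuda', (256, 256))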

For some reason, the difference doesn’t appear until a particular size is reached (more than ~60K elements in my observations). Now I don’t know whether this is torch using the cuRAND APIs incorrectly or a bug within CUDA itself :confused:
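To narrow down that size threshold, one could hash the `bernoulli_` output for a sweep of sizes on each GPU and diff the printed lines between machines. A minimal sketch; `bernoulli_digests` and the power-of-two size range are illustrative choices, and it falls back to the CPU when no GPU is present.

```python
import hashlib

import torch

def bernoulli_digests(seed: int, device: str, sizes) -> dict:
    """Hash bernoulli_(0.5) samples for each size so two machines' outputs can be diffed."""
    results = {}
    for n in sizes:
        torch.manual_seed(seed)  # re-seed so every size starts from the same RNG state
        x = torch.empty(n, device=device).bernoulli_(0.5)
        data = x.to(torch.uint8).cpu().numpy().tobytes()
        results[n] = hashlib.sha256(data).hexdigest()[:16]
    return results

device = "cuda" if torch.cuda.is_available() else "cpu"
for n, digest in bernoulli_digests(42, device, [2 ** k for k in range(10, 19)]).items():
    print(n, digest)
```

The first size whose digest differs between, say, the V100 and the T4 brackets the threshold more precisely than a single (256, 256) run.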