Output from network varies depending on CPU or GPU

Hi,

I have come across a problem where, despite setting random seeds, I obtain different outputs from a simple network depending on whether I run it on the CPU or the GPU. I also get different CPU results on different computers, while the GPU results are identical across them.

The model weights are equal, as is the randomly generated input “a”; however, differences appear after passing “a” through the network. I get the same behaviour whether I use model.eval() or model.train().

So my two questions are:

  1. Why are the results different when using the CPU vs. the GPU, even though the random seeds have been set?
  2. Probably linked to the first, but why do the CPU results differ across computers?

My results (the first output line is the GPU run, the second the CPU run) using:
System details: intel i5-7500, GTX 1060 6GB, 32GB RAM, Samsung Evo 850 500GB
nvidia-smi: 435.21

output[0, :]: [[-0.18105512857437134], [-0.9114810228347778], [-0.5673332810401917], [-1.0145820379257202]]
output[0, :]: [[-0.18105511367321014], [-0.9114810228347778], [-0.5673332214355469], [-1.0145819187164307]]

When I run this code on another computer using:
Intel® Xeon® Gold 6134 CPU @ 3.20GHz, GTX 1080 Ti
nvidia-smi: 430.50

output[0, :]: [[-0.18105512857437134], [-0.9114810228347778], [-0.5673332810401917], [-1.0145820379257202]]
output[0, :]: [[-0.18105512857437134], [-0.9114810228347778], [-0.5673332214355469], [-1.0145819187164307]]

import torch
import torch.nn as nn
import numpy
import random
import math


class SimpleModel1(nn.Module):
    def __init__(self):
        super(SimpleModel1, self).__init__()
        self.hidden = nn.Linear(in_features=6,
                                out_features=1)
        self.init_weights()

    def init_weights(self):
        self.init_layer(self.hidden)

    def init_layer(self, layer):
        (n_out, n) = layer.weight.size()
        std = math.sqrt(2. / n)
        scale = std * math.sqrt(3.)
        layer.weight.data.uniform_(-scale, scale)

        if layer.bias is not None:
            layer.bias.data.fill_(0.)

    def forward(self, x):
        x = self.hidden(x)
        return x


if __name__ == '__main__':
    model_weights = []
    pre_a = []
    post_a = []
    for i in range(2):
        chosen_seed = 0
        torch.manual_seed(chosen_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.cuda.manual_seed_all(chosen_seed)
        numpy.random.seed(chosen_seed)
        random.seed(chosen_seed)
       
        # (batch, time_frame, feature)
        a = torch.rand(2, 4, 6)
        model = SimpleModel1()

        if i == 0:
            model.cuda()
            a = a.cuda()
        test_model1 = []
        test_a_1 = a
        for params in model.parameters():
            test_model1.append(params.data)

        model.train()

        output = model(a)
        test_model2 = []
        for params in model.parameters():
            test_model2.append(params.data)
        print(f"output[0, :]: {output[0, :, :].tolist()}")
        model_weights.append([test_model1[0], test_model2[0]])
        pre_a.append(test_a_1)
        post_a.append(output)

    for i in range(2):
        if torch.all(torch.eq(model_weights[0][i].cpu(), model_weights[1][i])):
            print('Model Weights OK')
        else:
            print('Error in Weights')
    if torch.all(torch.eq(pre_a[0].cpu(), pre_a[1])):
        print('Pre-model inputs OK')
    else:
        print('Error in pre_model inputs')
    if torch.all(torch.eq(post_a[0].cpu(), post_a[1])):
        print('Post-model outputs OK')
    else:
        print('Error in post_model outputs')

My environment details:

name: repro
channels:
  - pytorch
  - anaconda
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - blas=1.0=mkl
  - ca-certificates=2019.11.27=0
  - certifi=2019.11.28=py37_0
  - cudatoolkit=10.1.243=h6bb024c_0
  - freetype=2.9.1=h8a8886c_1
  - intel-openmp=2019.5=281
  - jpeg=9b=h024ee3a_2
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_0
  - mkl=2019.5=281
  - mkl-service=2.3.0=py37he904b0f_0
  - mkl_fft=1.0.15=py37ha843d7b_0
  - mkl_random=1.1.0=py37hd6b4f25_0
  - ncurses=6.1=he6710b0_1
  - ninja=1.9.0=py37hfd86e86_0
  - numpy=1.18.1=py37h4f9e942_0
  - numpy-base=1.18.1=py37hde5b4d6_1
  - olefile=0.46=py37_0
  - openssl=1.1.1=h7b6447c_0
  - pillow=7.0.0=py37hb39fc2d_0
  - pip=19.3.1=py37_0
  - python=3.7.3=h0371630_0
  - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0
  - readline=7.0=h7b6447c_5
  - setuptools=44.0.0=py37_0
  - six=1.13.0=py37_0
  - sqlite=3.30.1=h7b6447c_0
  - tk=8.6.8=hbc83047_0
  - torchvision=0.5.0=py37_cu101
  - wheel=0.33.6=py37_0
  - xz=5.2.4=h14c3975_4
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.3.7=h0b5b093_0

The differences you are seeing are approx. 1e-6, which is most likely due to the limited floating point precision of float32.
A different order of operations can yield slightly different results, as seen here:

x = torch.randn(10, 10, 10)
s1 = x.sum()
s2 = x.sum(0).sum(0).sum(0)
print((s1 - s2).abs().max())
> tensor(3.8147e-06)
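Given differences of that magnitude, a tolerance-based comparison such as torch.allclose is usually more appropriate than the exact torch.eq checks in your script. A minimal sketch (out_cpu and out_gpu are placeholder tensors, not your actual outputs):

import torch

# Placeholders standing in for the CPU and GPU outputs of the script above;
# the second differs from the first only by simulated float32 rounding noise.
out_cpu = torch.randn(2, 4, 1)
out_gpu = out_cpu + 1e-7 * torch.randn(2, 4, 1)

print(torch.equal(out_cpu, out_gpu))                # False: exact comparison fails
print(torch.allclose(out_cpu, out_gpu, atol=1e-5))  # True: equal within a float32-level tolerance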

I assume the difference in CPU architecture between the two posted CPUs might explain the discrepancy, but that’s just my best guess.

Also, the implementation of the pseudo-random number generator might be different for different hardware devices (e.g. CPU vs. GPU).
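For example, seeding makes each generator reproducible on its own, but the CPU generator and the CUDA generator are separate implementations, so sampling directly on the GPU gives a different sequence than sampling on the CPU with the same seed. Your script sidesteps this for the input, because a is created on the CPU and then moved with a.cuda(), which copies the values exactly. A small sketch (assuming a CUDA device is available):

import torch

torch.manual_seed(0)
cpu_sample = torch.rand(4)                   # drawn from the CPU generator

torch.manual_seed(0)                         # re-seeds the CUDA generators as well
gpu_sample = torch.rand(4, device='cuda')    # drawn from the CUDA generator

# The two sequences generally do not match, even with identical seeds.
print(cpu_sample)
print(gpu_sample.cpu())

# Copying a CPU-generated tensor to the GPU preserves the values exactly.
print(torch.equal(cpu_sample, cpu_sample.cuda().cpu()))  # True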

Thanks for your answer. I tried increasing the model complexity and printing out the maximum difference between the GPU and CPU results, and I think you are correct: the highest difference I saw was 2.3543834686279297e-06.
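Roughly, the check looks like this (a minimal sketch that assumes the post_a list from the script above, where post_a[0] is the GPU output and post_a[1] the CPU output):

# Assumes post_a from the script above: post_a[0] is the GPU output, post_a[1] the CPU output.
max_diff = (post_a[0].cpu() - post_a[1]).abs().max()
print(f"max abs difference between GPU and CPU outputs: {max_diff.item()}")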

Architecture difference is also a good guess.
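One further sanity check that supports the float32 explanation: repeating the forward pass in double precision should shrink the CPU/GPU gap by several orders of magnitude. A minimal sketch, reusing the SimpleModel1 class from above (assuming a CUDA device is available):

import torch

torch.manual_seed(0)
a = torch.rand(2, 4, 6, dtype=torch.float64)
model = SimpleModel1().double()       # same architecture, float64 weights

out_cpu = model(a)                    # CPU forward pass
out_gpu = model.cuda()(a.cuda())      # same weights, moved to the GPU

# With float64, any remaining difference should be around 1e-15 rather than 1e-6.
print((out_cpu - out_gpu.cpu()).abs().max())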