Inconsistent evaluation between CPU and GPU

Hi.

A colleague and I have been struggling with this problem for the last few days and we are losing our sanity. Unfortunately, the conditions needed to recreate this problem are so rare and fragile that I have to attach a ~600 MB data file to this post for anyone interested in reproducing it. This data file does not contain any personal or sensitive data; it's just some random Monte Carlo numbers generated by another program.

Anyway here is the problem:

We have a really simple model: 7 input neurons, going to 5 hidden neurons and then a single output. Everything is built from linear layers with sigmoid activations. In the example below, the model doesn't even need to be trained for this problem to appear.

When we evaluate this model on some large datasets (~10 Million points), there are some inconsistencies between the CPU and GPU evaluations. In particular, the CPU is always right but the GPU starts giving some random results after about 2 Million points.

Here are some plots that show this behavior:

  • histogram of the difference between cpu evaluations and gpu evaluations. You can see a peak at 0 where the two evaluations coincide

  • plot of the same difference. You can see that the two evaluations coincide for the first 2 Million points and then it becomes noise.

If you want to reproduce these plots, here is a minimal working example:

import torch, os, h5py
from torch import nn
import matplotlib.pyplot as plt

class Model(nn.Module):    
    def __init__(self, input_dim, output_dim, hidden_units):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_units, bias = False)
        self.fc2 = nn.Linear(hidden_units, output_dim, bias = False)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

#create the model
model = Model(7, 1, 5)

#read the data file
data = h5py.File(os.getcwd() + '/data.h5', 'r')
test_input = torch.Tensor((data['Data'])[()])

#shuffle the data
test_input = test_input[torch.randperm(test_input.size(0))]

#get indices for which variable 6 is between 350 and 450
idx1 = test_input[:, 6] > 350
idx2 = test_input[:, 6] < 450

#normalize data
norm_test_input = (test_input - test_input.mean(0))/test_input.std(0)

#take only data between 350 and 450
test_input_cut = norm_test_input[idx1 & idx2]

print(test_input_cut.shape)

# --- #

#evaluate the model on CPU
print("CPU")
with torch.no_grad():
    print(model(test_input_cut)[-10:])
    print(model(test_input_cut[-10:]))

    cpu_results = model(test_input_cut)

#evaluate the model on GPU
print("GPU")
model.cuda()    
test_input_cut = test_input_cut.cuda()
with torch.no_grad():
    print(model(test_input_cut)[-10:])
    print(model(test_input_cut[-10:]))

    gpu_results = model(test_input_cut)
                     
#plot difference between the two evaluations       
diff = (cpu_results - gpu_results.cpu()).squeeze().numpy()
plt.hist(diff, bins = 30)
plt.show()

plx = range(cpu_results.size(0))
ply = (cpu_results.squeeze() - gpu_results.squeeze().cpu()).numpy()
plt.plot(plx, ply)
plt.show()

And here is the 600 MB .h5 file (which needs to be put in the same folder as the Python script) for which this problem appears.

This behavior is really fragile in this example: changing the two numbers (350 and 450) even slightly makes the problem disappear and everything works fine. For different (or randomly generated) data we couldn't find any "cut" that would make this appear, but from what we tried it doesn't look like there is anything wrong with the input data. In the real case, i.e. in the full project code, this seems to happen more consistently. We can't post the full source code, so this tiny fragile example is the only thing we could come up with that reproduces the issue consistently enough to be posted in a forum.

Thanks everybody in advance for your help.

This seems to be an interesting problem, but unfortunately I couldn’t reproduce this issue.
I used your code and the provided data, but the differences on my system seem to be alright.
[Attached plots: res1, res2]

Also, printing the unique values for diff gives:

np.unique(diff, return_counts=True)
(array([-1.1920929e-07, -5.9604645e-08,  0.0000000e+00,  5.9604645e-08,
         1.1920929e-07], dtype=float32),
 array([    765,  441331, 3331571,  101474,       2]))

which is in the expected floating point precision range.
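As a rough illustration of what "expected floating point precision range" means here: float32 machine epsilon is 2**-23 (about 1.19e-07), which matches the largest differences listed above. The sketch below counts how many outputs differ by more than a few epsilons. The helper name `summarize_diff` and the `n_eps` threshold are hypothetical choices for illustration, not something from the original code.

```python
import torch

def summarize_diff(cpu_out, gpu_out, n_eps=4):
    """Compare two float32 result tensors against a tolerance of n_eps machine epsilons.

    The outputs here are sigmoid values in (0, 1), so an absolute
    tolerance is a reasonable (assumed) choice.
    """
    a = cpu_out.detach().cpu().flatten()
    b = gpu_out.detach().cpu().flatten()
    eps = torch.finfo(torch.float32).eps  # 2**-23
    abs_diff = (a - b).abs()
    return {
        "eps": eps,
        "max_abs_diff": abs_diff.max().item(),
        "n_outside_tolerance": int((abs_diff > n_eps * eps).sum().item()),
    }
```

With the tensors from the script above, `summarize_diff(cpu_results, gpu_results)` should report zero entries outside the tolerance on a healthy setup and a large count when the bug is present.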

I’m using a 1080Ti, the 418.56 driver and CUDA 10.1.

Hi, thanks for your answer. We tested this on two different machines and we get the same wrong result on both.

Machine 1:

GPU: GTX 1070
Pytorch: 1.1.0
Cuda: 9.0.36
Driver: 430.14

Machine 2:

GPU: GTX 1050
Pytorch: 1.0.1
Cuda: 9.0
Driver: 430.86

It’s weird that you don’t find the same. Could it be related to CUDA 9 vs 10?
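When comparing setups like this, it can help to have each machine print the versions PyTorch itself sees. A minimal sketch (the function name `describe_environment` is made up for illustration; note that `torch.version.cuda` is the CUDA version the PyTorch binary was built with, and the driver version is not exposed this way):

```python
import torch

def describe_environment():
    """Collect the version information PyTorch reports about itself."""
    info = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "gpu": None,
    }
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
    return info

print(describe_environment())
```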

I’m not sure, but it would be interesting to see if updating to CUDA 10 helps.
Would that be possible, or are you stuck with CUDA 9 on your machines?

I have updated the GTX 1050 machine to PyTorch 1.1.0 and CUDA 10, and now the problem no longer appears, either in the example or in the full code. I still wonder what the problem was.

I could reproduce this issue on my system with CUDA9.0.
I also debugged a little bit and found that the GPU outputs a constant value after 2097152 input values, which is exactly 2**21.

print(gpu_results[2**21:].shape)
> torch.Size([1777991, 1])
print(gpu_results[2**21:].unique())
> tensor([0.5127], device='cuda:0', grad_fn=<NotImplemented>)

I’m not sure what’s going on, but maybe @ngimel knows of some limitation regarding this magic number and CUDA 9.
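For anyone who wants to check whether they see the same boundary, here is a small sketch that finds the first row where two result tensors disagree and tests whether that index is a power of two. The helper names and the `atol` value are assumptions chosen for illustration.

```python
import torch

def first_divergence(reference, candidate, atol=1e-6):
    """Return the index of the first entry differing by more than atol, or None."""
    ref = reference.detach().cpu().flatten()
    cand = candidate.detach().cpu().flatten()
    bad = ((ref - cand).abs() > atol).nonzero()
    if bad.numel() == 0:
        return None
    return int(bad.min().item())

def is_power_of_two(n):
    """True if n is a positive integer power of two."""
    return n > 0 and (n & (n - 1)) == 0
```

Calling `first_divergence(cpu_results, gpu_results)` on the outputs of the script above and passing the result to `is_power_of_two` would show whether the breakdown starts at a boundary such as 2**21.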

The number 2**21 makes me think that you are hitting https://github.com/pytorch/pytorch/pull/22034. The solution is to update to CUDA 9.2.
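If upgrading CUDA is not immediately possible, one thing that might be worth trying is evaluating the data in chunks well below 2**21 rows. This is only a sketch of a possible stopgap: whether it actually avoids the bug has not been verified in this thread, and the function name `evaluate_in_chunks` and the default chunk size are assumptions.

```python
import torch

def evaluate_in_chunks(model, inputs, chunk_size=2**20):
    """Run the model on successive row slices and concatenate the outputs.

    chunk_size is an assumed value chosen to stay below the 2**21 rows
    at which the wrong results were observed.
    """
    outputs = []
    with torch.no_grad():
        for chunk in torch.split(inputs, chunk_size, dim=0):
            outputs.append(model(chunk))
    return torch.cat(outputs, dim=0)
```

In the script above this would replace the direct call, e.g. `gpu_results = evaluate_in_chunks(model, test_input_cut)`.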
