Tensor indexing can be incorrect in PyTorch CUDA

I am trying to set several positions of a tensor in PyTorch on CUDA, but the result differs from the analogous operation in NumPy. Please see the code below.
As an illustration, I have a 1D array c and want to set its positions i to the values of r. E.g. if c = [3, 4, 5, 6], i = [1, 3], r = [10, 11], I do c[i] = r and c becomes [3, 10, 5, 11].
While this works fine for small arrays in PyTorch CUDA, it can be incorrect for large arrays: in the code below, the CUDA result differs from the NumPy result, and it shouldn't.
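A minimal CPU-only sketch of the small example above. Because the indices in i are all distinct, the assignment is well defined and NumPy and PyTorch agree; the variable names with `_np`/`_t` suffixes are just for illustration.

```python
import numpy as np
import torch

# Example values from the question; the indices in i are all distinct.
c_np = np.array([3, 4, 5, 6])
i_np = np.array([1, 3])
r_np = np.array([10, 11])
c_np[i_np] = r_np  # set positions 1 and 3

# The same assignment on a CPU tensor.
c_t = torch.tensor([3, 4, 5, 6])
c_t[torch.tensor([1, 3])] = torch.tensor([10, 11])

print(c_np.tolist(), c_t.tolist())  # [3, 10, 5, 11] [3, 10, 5, 11]
```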

import torch
import numpy as np

print("Numpy version", np.__version__)
print("Torch version", torch.__version__)
print("Torch cuda version", torch.version.cuda)

np.random.seed(0)
# int64 explicitly, so the arrays map onto torch.long on every platform
c = np.random.randint(10000, size=100000, dtype=np.int64)
i = np.random.randint(c.shape[0], size=10000, dtype=np.int64)
s = np.random.randint(10000, size=10000, dtype=np.int64)

# Copy the arrays to the GPU before c is modified below
c_gpu = torch.from_numpy(c).cuda()
i_gpu = torch.from_numpy(i).cuda()
s_gpu = torch.from_numpy(s).cuda()

c[i] = s              # assignment in NumPy
c_gpu[i_gpu] = s_gpu  # the same assignment on CUDA

# Number of positions where the two results disagree
print(np.sum(c != c_gpu.cpu().numpy()))

Actual output:

Numpy version 1.16.3
Torch version 1.1.0
Torch cuda version 10.0.130
208

Every execution gives a different number (208, 225, 301, ...), whereas the expected output is always zero.
Any idea why? Is this a bug, or am I missing something? Any help is appreciated. Thank you.


I think the reason for this non-deterministic behavior on CUDA tensors is that you are using repeated indices in i and thus also in i_gpu (you can check it with i_gpu.unique().nelement(), which is smaller than i_gpu.nelement() when there are repeats).
Since the operation is performed in parallel on the GPU, a position that appears more than once in the indices ends up with the value from whichever thread happens to write to it last, and that order is not fixed.
You can also trigger this behavior with a single value:

import torch

c = torch.zeros(1, dtype=torch.long)
i = torch.zeros(10000, dtype=torch.long)  # the same index repeated 10000 times
s = torch.arange(10000)

c[i] = s  # On CPU this gives the last value (9999)
print(c)

c_gpu = torch.zeros(1, dtype=torch.long, device='cuda')
i_gpu = i.to('cuda')
s_gpu = s.to('cuda')

c_gpu[i_gpu] = s_gpu
print(c_gpu)  # Prints an arbitrary value of s, which can change from run to run

I’m not sure if this is a bug (in which case the result should be the last value in s) or if this is just undefined behavior.
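A small sketch of the duplicate check mentioned above, run on CPU with a randomly generated index tensor shaped like i in the question (10000 positions drawn with replacement from 100000); the seed and tensor here are illustrative, not the original data. The same calls work unchanged on i_gpu.

```python
import torch

torch.manual_seed(0)
# Illustrative index tensor: 10000 positions drawn with replacement from 100000.
i = torch.randint(100000, (10000,))

n_total = i.nelement()
n_unique = i.unique().nelement()
print(n_total, n_unique, n_unique < n_total)  # True means some indices repeat

# Which positions are written more than once, and how often.
values, counts = i.unique(return_counts=True)
repeated = values[counts > 1]
print(repeated.nelement(), "positions are written more than once")
```

With 10000 draws from 100000 positions, a few hundred repeats are expected, which is in the same range as the 208 to 301 mismatches reported in the question.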


Shouldn’t this follow the behaviour we have in NumPy, or at least in CPU Torch? The operations in NumPy and Torch CPU are also in parallel; there is just more parallelism on the GPU than on the CPU. In the experiments I have performed, Torch on CPU at least gives the same results as NumPy.

Generally, yes. I would expect the GPU implementation to show the same behavior as the CPU one.

However, are we enforcing the current behavior on the CPU (or is it defined somewhere else), or are we just “lucky” to get the same result for overlapping indices?

Also, thanks for bringing up this issue. :wink:

CC @colesbury, @ngimel

I'm not sure how this is implemented in Torch CUDA, but I believe parallelism and optimization together should give us the fastest and, at the same time, the correct answer. Since the CPU time for this is quite fast, I am sure the optimization does its job for overlapping indices in both Torch (CPU) and NumPy.
In my view, if overlapping indices are allowed in Torch (CUDA), the behaviour should be deterministic and the same as on CPU; otherwise it should raise a runtime error.
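If a deterministic "last write wins" result is what you need, one option is to deduplicate the indices yourself before assigning, so every target is written exactly once. Below is a minimal sketch, assuming 1-D tensors, non-negative indices, and a PyTorch version where `torch.sort` accepts `stable=True` (newer than the 1.1 used in this thread); `index_put_last_wins` is a hypothetical helper, not a PyTorch API.

```python
import torch

def index_put_last_wins(c, i, s):
    """Hypothetical helper (not part of PyTorch): c[i] = s where, for a
    repeated index, the value from its LAST occurrence in i is kept.
    Assumes 1-D c, i, s and non-negative indices; works on CPU or CUDA."""
    # A stable sort keeps equal indices in their original order,
    # so the last element of each run is the last occurrence in i.
    sorted_i, order = torch.sort(i, stable=True)
    # Mark the last element of each run of equal indices.
    is_last = torch.ones_like(sorted_i, dtype=torch.bool)
    is_last[:-1] = sorted_i[1:] != sorted_i[:-1]
    # After deduplication every target position is written exactly once.
    c[sorted_i[is_last]] = s[order[is_last]]
    return c

c = torch.zeros(5, dtype=torch.long)
i = torch.tensor([0, 0, 2, 2, 2, 4])
s = torch.tensor([1, 2, 3, 4, 5, 6])
print(index_put_last_wins(c, i, s).tolist())  # [2, 0, 5, 0, 6]
```

The extra sort costs time compared with a plain c[i] = s, which is presumably why the default assignment does not do this for you.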

The operation here is index_put_ with accumulate=False. The manual says: "If accumulate is False, the behavior is undefined if indices contain duplicate elements."
https://pytorch.org/docs/stable/tensors.html?highlight=index_put#torch.Tensor.index_put_. The same is true for NumPy:
“For advanced assignments, there is in general no guarantee for the iteration order. This means that if an element is set more than once, it is not possible to predict the final result.”
https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html
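When duplicates are meant to be combined rather than overwritten, both libraries do have well-defined operations. A minimal CPU sketch: in PyTorch, index_put_ with accumulate=True adds every value written to the same index, and in NumPy, np.add.at is the unbuffered counterpart of an indexed +=. For integer dtypes the resulting sum does not depend on the order of the writes.

```python
import numpy as np
import torch

idx = [0, 0, 2]   # index 0 appears twice
vals = [1, 2, 3]

# PyTorch: accumulate=True sums all values that target the same index.
c = torch.zeros(3, dtype=torch.long)
c.index_put_((torch.tensor(idx),), torch.tensor(vals), accumulate=True)

# NumPy: np.add.at applies every addition, including repeated indices.
a = np.zeros(3, dtype=np.int64)
np.add.at(a, idx, vals)

print(c.tolist(), a.tolist())  # [3, 0, 3] [3, 0, 3]
```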
