Indexing: unexpected behavior in CUDA

I am having a problem with indexing torch tensors. I get wrong results when I run the following code on the GPU, and the result is different each time the model is run.
The following line creates the problem:

tuss[j,:,self.coordlist[:]] += x[j,:,:]

in this test code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class polartospectral(nn.Module):
    def __init__(self, intpoints, coordlist, cpu=True):
        super(polartospectral, self).__init__()
        self.coordlist = coordlist
        # Stored as use_cpu so the flag does not shadow nn.Module.cpu()
        self.use_cpu = cpu
        if not self.use_cpu:
            self.coordlist = self.coordlist.cuda()

    def forward(self, x):
        tuss = torch.zeros(len(x),100,100)
        if not self.use_cpu:
            tuss = tuss.cuda()
        for j in range(len(x)):
            # This is the line that gives inconsistent results on the GPU
            tuss[j,:,self.coordlist[:]] += x[j,:,:]
        return tuss
    
    
coordlist = torch.tensor([np.random.randint(0,100) for i in range(100)],dtype=torch.long)
model_cpu = polartospectral(1, coordlist, cpu=True)
model_gpu = polartospectral(1, coordlist, cpu=False).cuda()
inp = torch.rand(1,100,100)*1000

uit1 =model_gpu(inp.cuda()).cpu().detach().numpy()[0,1]
uit2 =model_cpu(inp.cpu() ).cpu().detach().numpy()[0,1]
print("this should be zero: ")
print(uit1-uit2)
print("output of GPU: ")
print(uit1)
print("output of CPU: ")
print(uit2)

So when I run

uit1 =model_gpu(inp.cuda()).cpu().detach().numpy()[0,1]

the output differs from run to run, although the same values are often encountered. It seems like things are not well synchronized. For model_cpu, the behaviour is as expected.

I found a different topic here, with a similar problem: Different behavior of advanced indexing on CPU and GPU
However, that topic attributes the problem to assigning different values. I am not doing that; I am adding things to a vector.

Is indexing of multiple values in the same expression something that should be avoided altogether? Are there workarounds?

Thanks!

Hi @zweetvoetje,
I raised an issue several months ago which, I think, touches on the same problem you have: Is there any alternative to numpy.add.at in PyTorch?

Likely, coordlist can have duplicate values, so += is ambiguous in this case and can cause undefined behavior on GPU. I would change

tuss[j,:,self.coordlist[:]] += x[j,:,:]

to

torch.index_add_(tuss[j], self.coordlist, x[j])

(I hope I am not mistaken about the syntax of index_add_)
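
To illustrate why duplicates matter, here is a minimal CPU-only sketch. The tensor names and values are made up for the illustration and are not taken from the code above. As far as I understand it, the indexed += reads the selected columns, adds, and writes the result back, so for a repeated index only one of the contributions survives, whereas index_add_ accumulates all of them.

```python
import torch

# Hypothetical toy data: column index 0 appears twice in idx.
idx = torch.tensor([0, 0, 2], dtype=torch.long)
src = torch.tensor([[1.0, 10.0, 100.0]])  # shape (1, 3)

# Indexed +=: behaves like dst[:, idx] = dst[:, idx] + src,
# so the two writes to column 0 overwrite each other.
dst_plus = torch.zeros(1, 4)
dst_plus[:, idx] += src

# index_add_ along dim 1 sums every contribution, duplicates included.
dst_add = torch.zeros(1, 4)
dst_add.index_add_(1, idx, src)

print("indexed +=  :", dst_plus)
print("index_add_  :", dst_add)
```

With index_add_, column 0 ends up holding 1 + 10. With the indexed +=, column 0 holds only one of the two values, and which one is kept should not be relied upon.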

Thank you for helping me out.

Yes, after replacing the expression with index_add_, the CPU and GPU results are practically the same. There are some round-off differences between CPU and GPU, but these are of the order 1e-8, so that’s the behavior that I expected.

For completeness: I had to replace the command

tuss[j,:,self.coordlist[:]] += x[j,:,:]

with:

tuss[j].index_add_(1, self.coordlist, x[j])

I’ve also noticed that the old version yielded wrong results on the CPU as well.
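
In case it is useful to others, here is a sketch of how the loop over j could probably be dropped by calling index_add_ once along the last dimension. The helper name scatter_columns and its width parameter are my own inventions for this example. The output is created on the same device as the input, so no cpu flag is needed. The CPU/GPU comparison uses torch.allclose with tolerances I picked arbitrarily, since the order of the additions on the GPU may differ between runs.

```python
import torch

def scatter_columns(x, coordlist, width=100):
    """Hypothetical helper: add x[..., k] into output column coordlist[k]."""
    # x: (batch, rows, n), coordlist: (n,) of dtype long
    out = torch.zeros(x.shape[0], x.shape[1], width,
                      dtype=x.dtype, device=x.device)
    # One call along dim 2 replaces the per-sample loop.
    out.index_add_(2, coordlist.to(x.device), x)
    return out

torch.manual_seed(0)
coordlist = torch.randint(0, 100, (100,), dtype=torch.long)
inp = torch.rand(2, 100, 100) * 1000

batched = scatter_columns(inp, coordlist)

# Reference: the per-sample version from the post above.
looped = torch.zeros(2, 100, 100)
for j in range(len(inp)):
    looped[j].index_add_(1, coordlist, inp[j])

print("batched matches loop:", torch.allclose(batched, looped))

# Optional GPU check, only run when CUDA is available.
if torch.cuda.is_available():
    gpu = scatter_columns(inp.cuda(), coordlist).cpu()
    print("GPU close to CPU:", torch.allclose(batched, gpu, rtol=1e-5, atol=1e-3))
```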