Something wrong with scatter on GPU but not on CPU

Hi all,
I want to set the k smallest values to zero. I found a solution online, and it works on my local machine (Windows 10, CPU mode) but fails on my server (Ubuntu, GPU mode). I don’t know what’s wrong with my code. Could anyone help me?
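For cross-checking a suspicious tensor by hand, here is a minimal plain-Python sketch of the intended operation. It assumes the data has been flattened to a list of `(N, C)` rows and that the k smallest values are picked independently in each column, which is what `topk(dim=1, largest=False)` followed by `scatter(dim=1)` does on the `(1, N, C)` tensor used later in the thread. The helper name `zero_k_smallest_per_column` is hypothetical. Ties are broken by row order, which may not match the choice `torch.topk` makes.

```python
from typing import List


def zero_k_smallest_per_column(rows: List[List[float]], k: int) -> List[List[float]]:
    """Return a copy of `rows` with the k smallest values in each column set to 0.

    Hypothetical reference helper: each column is handled independently.
    """
    if not rows:
        return []
    n_rows, n_cols = len(rows), len(rows[0])
    if not 0 <= k <= n_rows:
        raise ValueError("k must be between 0 and the number of rows")
    out = [list(r) for r in rows]  # copy so the caller's data is untouched
    for c in range(n_cols):
        # Row indices of column c ordered by value; sorted() is stable,
        # so equal values keep their original row order.
        order = sorted(range(n_rows), key=lambda r: rows[r][c])
        for r in order[:k]:
            out[r][c] = 0.0
    return out
```

Applied with `k=3` to the ten pairs of values printed in the reply below, it zeroes rows 2, 7 and 9 of the first column and rows 2, 4 and 6 of the second, the same positions zeroed in that reply's GPU run.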


[Image: raw tensor]

[Image: result on Windows 10 in CPU mode (correct)]

[Image: result on the Ubuntu server in GPU mode (wrong)]

I found that torch.topk returns the correct indices, but after scatter the result is wrong on the Ubuntu machine in GPU mode.
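One way to pin down exactly where scatter goes wrong is to compare its output against the indices `topk` returned: every indexed position should be zero and every other position should be unchanged. A minimal plain-Python sketch, assuming the tensors have been moved to nested lists (for example with `.tolist()`) and reduced to 2-D, so the original and result are `(N, C)` and the indices are `(k, C)`. The helper name `find_scatter_mismatches` is hypothetical.

```python
from typing import List, Tuple


def find_scatter_mismatches(
    original: List[List[float]],
    result: List[List[float]],
    indices: List[List[int]],
) -> List[Tuple[int, int]]:
    """Return (row, col) positions where `result` differs from the expected scatter.

    `indices[j][c]` is the row that should be zeroed in column c, which is how
    topk lays out its indices along the reduced dimension.
    """
    n_rows, n_cols = len(original), len(original[0])
    # Positions that scatter should have overwritten with zero.
    targets = {(indices[j][c], c) for j in range(len(indices)) for c in range(n_cols)}
    mismatches = []
    for r in range(n_rows):
        for c in range(n_cols):
            expected = 0.0 if (r, c) in targets else original[r][c]
            if result[r][c] != expected:
                mismatches.append((r, c))
    return mismatches
```

An empty list means the scatter output is consistent with the indices; a non-empty list points at the specific cells to inspect on the GPU run.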

The following code works correctly on my Linux box (Ubuntu 16.04.5 LTS) with a GPU.

>>> import torch
>>> torch.manual_seed(12345)
<torch._C.Generator object at 0x7fe8d0a3d710>
>>> data = torch.rand(2, 10)
>>> f = torch.FloatTensor(data).cuda()
>>> f = f.view(1, f.shape[1], f.shape[0])
>>> values, indices = torch.topk(f, k=3, dim=1, largest=False)
>>> f
tensor([[[0.9817, 0.8796],
         [0.9921, 0.4611],
         [0.0832, 0.1784],
         [0.3674, 0.5676],
         [0.3376, 0.2119],
         [0.4594, 0.8154],
         [0.9157, 0.2531],
         [0.2133, 0.4770],
         [0.7201, 0.7238],
         [0.3139, 0.6732]]], device='cuda:0')
>>> f.scatter(dim=1, index=indices, source=0)
tensor([[[0.9817, 0.8796],
         [0.9921, 0.4611],
         [0.0000, 0.0000],
         [0.3674, 0.5676],
         [0.3376, 0.0000],
         [0.4594, 0.8154],
         [0.9157, 0.0000],
         [0.0000, 0.4770],
         [0.7201, 0.7238],
         [0.0000, 0.6732]]], device='cuda:0')
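For a 3-D tensor and `dim=1`, the PyTorch documentation gives the scatter rule as `self[i][index[i][j][k]][k] = src[i][j][k]`; when a scalar is passed instead of a source tensor, that scalar is written at each indexed position. `scatter` returns a modified copy, while `scatter_` works in place. The sketch below emulates that rule on nested lists so the expected output can be reproduced without a GPU. `scatter_dim1` is a hypothetical name, not a PyTorch function.

```python
import copy
from typing import List


def scatter_dim1(tensor: List[List[List[float]]],
                 index: List[List[List[int]]],
                 value: float) -> List[List[List[float]]]:
    """Emulate scatter along dim=1 with a scalar value on a 3-D nested list.

    Applies out[i][index[i][j][k]][k] = value for every position of `index`.
    The input is left unchanged; a deep copy is returned.
    """
    out = copy.deepcopy(tensor)
    for i, plane in enumerate(index):
        for row in plane:
            for k, target_row in enumerate(row):
                out[i][target_row][k] = value
    return out
```

Reading the printed tensor by hand, the three smallest values are in rows 2, 7, 9 of the first column and rows 2, 4, 6 of the second, so one valid index is `[[[2, 2], [7, 4], [9, 6]]]` (the order within a column does not affect the result). Emulating scatter with that index and `0.0` zeroes the same cells as the GPU output above.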

I am using PyTorch 1.0.1.

python -c "import torch; print(torch.__version__)"

Please let me know your PyTorch version.

Thank you, Tony.
I’m confused now: I tested your code and it works correctly in GPU mode. My PyTorch version is 0.4.1.post2.
I don’t know why my project still gets the wrong result.

Could you run your original code using PyTorch 1.0.1?

I’ll try it after upgrading.

Tony, my project works correctly after updating PyTorch from 0.4.1 to 1.0.0. I couldn’t find the command to update to 1.0.1, but it works with 1.0.0. Thank you!
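Since upgrading resolved the problem, a project that depends on this behaviour could refuse to run on older releases. A minimal sketch, assuming 1.0.0 as the minimum only because it is the version reported to work here (this thread does not establish which release actually changed the GPU behaviour). `MIN_VERSION`, `parse_version_prefix` and `is_supported` are hypothetical names; in practice you would pass `torch.__version__`, the same value printed by the command earlier in the thread.

```python
import re
from typing import Tuple

# Assumed minimum: 1.0.0 is the release reported to work in this thread.
MIN_VERSION = (1, 0, 0)


def parse_version_prefix(version: str) -> Tuple[int, int, int]:
    """Parse the leading major.minor[.patch] of a version string.

    Suffixes such as '.post2' or '+cu90' are ignored, e.g. '0.4.1.post2' -> (0, 4, 1).
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def is_supported(version: str) -> bool:
    """Return True if `version` is at least MIN_VERSION."""
    return parse_version_prefix(version) >= MIN_VERSION
```

Comparing integer tuples avoids the string-comparison trap where, for example, "0.10" would sort before "0.4".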