index_add_ doesn't work with DataParallel

🐛 Bug

index_add_ (and probably similar in-place indexing functions such as index_copy_, though I have not tested those) gives wrong results when used inside a model that has been wrapped with DataParallel.
Even when the model is wrapped with DataParallel, a forward function that uses index_add_ for its calculations should behave the same way it does on a single GPU.
See the output log below, which illustrates the problem.
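
For reference, here is a minimal CPU-only sketch of what index_add_ does on a single device: row i of the source is added into the destination row selected by index[i], and repeated indices accumulate. The tensor names and values (dest, src, index) are made up for illustration and are not part of the repro.

```python
import torch

# Destination with 4 rows; source with 3 rows to be added into it.
dest = torch.zeros(4, 2, dtype=torch.int)
src = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int)
index = torch.tensor([0, 2, 0])  # row 0 is targeted twice

# Along dim 0: dest[index[i]] += src[i] for every i.
dest.index_add_(0, index, src)
print(dest)  # rows: [6, 8], [0, 0], [3, 4], [0, 0]
```

Rows that no index points at stay at zero, which matters for reading the log further down.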

To Reproduce

Steps to reproduce the behaviour:

  1. Run the dummy code snippet below.
  2. Run it on 2 GPUs (export CUDA_VISIBLE_DEVICES=0,1).
import torch

idx = torch.arange(0,40, device=torch.device('cuda:0'), dtype=torch.long).reshape(4,10)
print("index:", idx.shape)

emb = torch.arange(10,130, dtype=torch.int, device=torch.device('cuda:0')).reshape(4,10,3)
print("t", emb.shape)

print("\n")

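class Index_Add_Checker(torch.nn.Module):
    def __init__(self, index, t):
        # index and t are accepted but not used; the module has no parameters.
        super().__init__()

    def forward(self, index, t):
        # Allocate the output on the same device as this call's inputs.
        pooled = torch.zeros(40, 3, dtype=torch.int, device=t.device)
        print("index:", index.shape)
        print("t:", t.shape)
        # Add each row of t into the row of pooled selected by index.
        pooled.index_add_(0, index.view(-1), t.view(-1, 3))
        return pooled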

model_dp = Index_Add_Checker(idx, emb)
model_dp = torch.nn.DataParallel(model_dp).cuda()

ans_dp = model_dp(idx, emb)
print("ans_dp shape:", ans_dp.shape)
print("ans_dp:", ans_dp)

print("\n=====================================================================\n")

model_without_dp = Index_Add_Checker(idx, emb)
ans = model_without_dp(idx, emb)
print("ans shape:", ans.shape)
print("ans:", ans)

Expected behaviour

ans and ans_dp should be the same, but ans_dp (the result from the DataParallel model) is not what I would expect from index_add_: it has shape [80, 3] instead of [40, 3], with a block of zero rows in the middle.
This is probably happening because DataParallel splits index and t along dim=0, so each replica sees only a [2, 10] slice of index and fills only its own rows of a separate 40-row pooled tensor. The per-replica results then appear to be concatenated rather than added together, which would explain the [80, 3] output.
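
The splitting described above can be emulated on CPU without any GPUs. The sketch below is an assumption about what is going on, not DataParallel's actual code: it uses torch.chunk to stand in for the scatter step and torch.cat to stand in for the gather step, and the standalone forward function and the NUM_ROWS, FEAT and num_replicas names are introduced here for illustration only.

```python
import torch

NUM_ROWS, FEAT = 40, 3

def forward(index, t):
    # Same computation as Index_Add_Checker.forward, on CPU.
    pooled = torch.zeros(NUM_ROWS, FEAT, dtype=t.dtype)
    pooled.index_add_(0, index.view(-1), t.view(-1, FEAT))
    return pooled

idx = torch.arange(0, 40, dtype=torch.long).reshape(4, 10)
emb = torch.arange(10, 130, dtype=torch.int).reshape(4, 10, 3)

# Single-device reference result.
ans = forward(idx, emb)

# Emulated data parallelism with 2 replicas (assumption): split the inputs
# on dim 0, run each chunk separately, concatenate the outputs on dim 0.
num_replicas = 2
outputs = [
    forward(i, t)
    for i, t in zip(idx.chunk(num_replicas, 0), emb.chunk(num_replicas, 0))
]
ans_dp = torch.cat(outputs, dim=0)
print(ans_dp.shape)  # torch.Size([80, 3])

# Possible workaround: sum the per-replica blocks instead of stacking them.
recovered = ans_dp.view(num_replicas, NUM_ROWS, FEAT).sum(dim=0, dtype=ans.dtype)
print(torch.equal(recovered, ans))  # True
```

If this emulation matches what DataParallel does, reshaping the returned tensor to (number of replicas, 40, 3) and summing over the first dimension may be enough to recover the single-GPU result. It relies on every replica allocating a full 40-row pooled tensor, as the module in the repro does.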

Output Log:

index: torch.Size([4, 10])
t torch.Size([4, 10, 3])


index: torch.Size([2, 10])
t: torch.Size([2, 10, 3])
index: torch.Size([2, 10])
t: torch.Size([2, 10, 3])
ans_dp shape: torch.Size([80, 3])
ans_dp: tensor([[ 10,  11,  12],
        [ 13,  14,  15],
        [ 16,  17,  18],
        [ 19,  20,  21],
        [ 22,  23,  24],
        [ 25,  26,  27],
        [ 28,  29,  30],
        [ 31,  32,  33],
        [ 34,  35,  36],
        [ 37,  38,  39],
        [ 40,  41,  42],
        [ 43,  44,  45],
        [ 46,  47,  48],
        [ 49,  50,  51],
        [ 52,  53,  54],
        [ 55,  56,  57],
        [ 58,  59,  60],
        [ 61,  62,  63],
        [ 64,  65,  66],
        [ 67,  68,  69],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [  0,   0,   0],
        [ 70,  71,  72],
        [ 73,  74,  75],
        [ 76,  77,  78],
        [ 79,  80,  81],
        [ 82,  83,  84],
        [ 85,  86,  87],
        [ 88,  89,  90],
        [ 91,  92,  93],
        [ 94,  95,  96],
        [ 97,  98,  99],
        [100, 101, 102],
        [103, 104, 105],
        [106, 107, 108],
        [109, 110, 111],
        [112, 113, 114],
        [115, 116, 117],
        [118, 119, 120],
        [121, 122, 123],
        [124, 125, 126],
        [127, 128, 129]], device='cuda:0', dtype=torch.int32)

=====================================================================

index: torch.Size([4, 10])
t: torch.Size([4, 10, 3])
ans shape: torch.Size([40, 3])
ans: tensor([[ 10,  11,  12],
        [ 13,  14,  15],
        [ 16,  17,  18],
        [ 19,  20,  21],
        [ 22,  23,  24],
        [ 25,  26,  27],
        [ 28,  29,  30],
        [ 31,  32,  33],
        [ 34,  35,  36],
        [ 37,  38,  39],
        [ 40,  41,  42],
        [ 43,  44,  45],
        [ 46,  47,  48],
        [ 49,  50,  51],
        [ 52,  53,  54],
        [ 55,  56,  57],
        [ 58,  59,  60],
        [ 61,  62,  63],
        [ 64,  65,  66],
        [ 67,  68,  69],
        [ 70,  71,  72],
        [ 73,  74,  75],
        [ 76,  77,  78],
        [ 79,  80,  81],
        [ 82,  83,  84],
        [ 85,  86,  87],
        [ 88,  89,  90],
        [ 91,  92,  93],
        [ 94,  95,  96],
        [ 97,  98,  99],
        [100, 101, 102],
        [103, 104, 105],
        [106, 107, 108],
        [109, 110, 111],
        [112, 113, 114],
        [115, 116, 117],
        [118, 119, 120],
        [121, 122, 123],
        [124, 125, 126],
        [127, 128, 129]], device='cuda:0', dtype=torch.int32)

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.3 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 7.5.17
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti
GPU 4: GeForce GTX 1080 Ti
GPU 5: GeForce GTX 1080 Ti
GPU 6: GeForce GTX 1080 Ti
GPU 7: GeForce GTX 1080 Ti

Nvidia driver version: 418.39
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.5.0
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.6
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn.so.7.0.5

Versions of relevant libraries:
[pip3] numpy==1.14.0
[pip3] numpydoc==0.7.0
[pip3] torch==1.0.1.post2
[pip3] torchvision==0.2.2.post3
[conda] torch                     1.1.0                    pypi_0    pypi
[conda] torch-cluster             1.3.0                    pypi_0    pypi
[conda] torch-geometric           1.2.0                    pypi_0    pypi
[conda] torch-scatter             1.2.0                    pypi_0    pypi
[conda] torch-sparse              0.4.0                    pypi_0    pypi
[conda] torch-spline-conv         1.1.0                    pypi_0    pypi
[conda] torchvision               0.2.2.post3              pypi_0    pypi
[conda] torchviz                  0.0.1                    pypi_0    pypi

This is a cross-post of https://github.com/pytorch/pytorch/issues/21810. If this is a proper issue please continue on GitHub.

Hi! I posted here because I didn’t get a reply on the GitHub issue tracker. I’m new to PyTorch, so I just wanted to check whether my report is actually right.