DataParallel issue with torch.spmm

:bug: Bug

To Reproduce

Steps to reproduce the behavior:

import torch
from torch import nn

class SparseTest(nn.Module):
    def __init__(self):
        super(SparseTest, self).__init__()
        self.S = torch.sparse_coo_tensor(
            indices=torch.tensor([[0, 0, 1, 2], [2, 3, 0, 3]]),
            values=torch.tensor([1.0, 2.0, 1.0, 3.0]),
            size=[3, 4]).cuda()
        self.fc = nn.Linear(6, 4) 

    def forward(self, x):
        x = torch.spmm(self.S, x)   # (3, 4) x (4, 2) -> (3, 2)
        x = x.reshape(-1)           # flatten to (6,)
        x = self.fc(x)              # (6,) -> (4,)
        return x

if __name__ == "__main__":

    X = torch.ones(4, 2, dtype=torch.float).cuda()
    y = torch.zeros(4, dtype=torch.float).cuda()
    sparseTest = SparseTest()
    sparseTest = sparseTest.cuda()
    sparseTest = torch.nn.DataParallel(sparseTest)  # comment this line out to run without DataParallel
    optimizer = torch.optim.Adam(sparseTest.parameters(), lr=0.001, weight_decay=0.00005)
    lossMSE = nn.MSELoss()
    with torch.set_grad_enabled(True):
        for i in range(10):
            x = sparseTest(X)
            optimizer.zero_grad()
            loss = lossMSE(x, y)
            loss.backward()
            optimizer.step()
            print("loss: {:.8f}".format(loss.item()))

Expected behavior

Without DataParallel, i.e., with the line sparseTest = torch.nn.DataParallel(sparseTest) commented out, training runs and the loss decreases:

$ python sparseTest.py 
loss: 1.46217334
loss: 1.42986751
loss: 1.39802396
loss: 1.36665058
loss: 1.33575463
loss: 1.30534196
loss: 1.27541876
loss: 1.24598980
loss: 1.21705961
loss: 1.18863106

With DataParallel, i.e., with sparseTest = torch.nn.DataParallel(sparseTest) enabled, the forward pass fails:

$ python sparseTest.py 
Traceback (most recent call last):
  File "sparseTest.py", line 31, in <module>
    x = sparseTest(X)
  File "/home/pai/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/pai/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/pai/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/pai/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/home/pai/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/home/pai/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "sparseTest.py", line 15, in forward
    x = torch.spmm(self.S, x)
RuntimeError: addmm: Argument #3 (dense): Expected dim 0 size 4, got 1
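
The error seems to come from DataParallel scattering X along dim 0: with 4 GPUs, each replica receives a (1, 2) slice of the (4, 2) input, so torch.spmm(self.S, x) is handed a dense argument with 1 row instead of 4. In addition, self.S is a plain attribute rather than a registered parameter or buffer, so it stays on cuda:0 and is not replicated to the other devices. A minimal workaround sketch, assuming the intent is simply to apply the same sparse matrix to every sample, is to give X an explicit batch dimension so the scatter splits samples instead of matrix rows, and to move S onto the input's device inside forward (SparseTestBatched is a hypothetical name, not part of the report above):

import torch
from torch import nn

class SparseTestBatched(nn.Module):
    def __init__(self):
        super(SparseTestBatched, self).__init__()
        # kept as a plain attribute on purpose; moved to the right device in forward()
        self.S = torch.sparse_coo_tensor(
            indices=torch.tensor([[0, 0, 1, 2], [2, 3, 0, 3]]),
            values=torch.tensor([1.0, 2.0, 1.0, 3.0]),
            size=[3, 4])
        self.fc = nn.Linear(6, 4)

    def forward(self, x):                 # x: (batch, 4, 2) after DataParallel's scatter
        S = self.S.to(x.device)           # each replica may run on a different GPU
        out = [torch.spmm(S, xb).reshape(-1) for xb in x]  # (3, 2) -> (6,) per sample
        return self.fc(torch.stack(out))  # (batch, 4)

Feeding e.g. X = torch.ones(8, 4, 2).cuda() and y = torch.zeros(8, 4).cuda() then gives each replica whole (4, 2) samples, so the spmm shapes line up regardless of how the batch is split across devices.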

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 390.116
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2
/usr/local/cuda-9.0/lib64/libcudnn.so.7

Versions of relevant libraries:
[pip3] numpy==1.16.3
[pip3] torch==1.0.1.post2
[pip3] torch-dct==0.1.5
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.2.2.post3
[conda] blas 1.0 mkl
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.12 py36ha843d7b_0
[conda] mkl_random 1.0.2 py36hd81dba3_0
[conda] pytorch 1.1.0 py3.6_cuda9.0.176_cudnn7.5.1_0 pytorch
[conda] torch-dct 0.1.5 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi

Additional context