Pytorch with Multi GPUs

Hi @ptrblck,

I got an error on running pytorch using 4 GPUs;

the device is defined to be;

device = 'cuda:0, cuda:1, cuda:2, cuda:3'
model = GCN().to(device)
model = torch.nn.DataParrallel(model)

and below is the model;

class GraphConvolution(Module):
    def __init__(self, in_features=100, out_features=20, dropout=0., act=F.relu):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        self.init = math.sqrt(6.0 / (self.in_features + self.out_features))
        self.weight = Parameter(torch.nn.init.uniform_(torch.FloatTensor(in_features, out_features), a=-self.init, b=self.init), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        input = F.dropout(input, self.dropout, self.training)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        output = self.act(output)
        return output

The error is;

 File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 148, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 159, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 36, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 28, in scatter
    res = scatter_map(inputs)
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 15, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 13, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 89, in forward
    outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  File "/home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/cuda/comm.py", line 147, in scatter
    return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
RuntimeError: sparse tensors do not have strides (strides at /opt/conda/conda-bld/pytorch_1570910687650/work/aten/src/ATen/SparseTensorImpl.cpp:52)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7fc0b0442687 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: at::SparseTensorImpl::strides() const + 0xae (0x7fc0b35b435e in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: at::native::slice(at::Tensor const&, long, long, long, long) + 0xd8 (0x7fc0b394a258 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x1b5a209 (0x7fc0b3c3f209 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x1ba112b (0x7fc0b3c8612b in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #5: at::native::narrow(at::Tensor const&, long, long, long) + 0x14a (0x7fc0b3953c4a in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #6: <unknown function> + 0x1b5a059 (0x7fc0b3c3f059 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #7: <unknown function> + 0x1a648c6 (0x7fc0b3b498c6 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x15ca0b7 (0x7fc0b36af0b7 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #9: at::native::split(at::Tensor const&, long, long) + 0x1c3 (0x7fc0b3954603 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #10: <unknown function> + 0x1b5a219 (0x7fc0b3c3f219 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #11: <unknown function> + 0x1ba0a83 (0x7fc0b3c85a83 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #12: <unknown function> + 0x36a6966 (0x7fc0b578b966 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #13: <unknown function> + 0x1ba0a83 (0x7fc0b3c85a83 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #14: at::native::chunk(at::Tensor const&, long, long) + 0x131 (0x7fc0b3952611 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #15: <unknown function> + 0x1b59e79 (0x7fc0b3c3ee79 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #16: <unknown function> + 0x36e34ab (0x7fc0b57c84ab in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #17: <unknown function> + 0x1ba0a83 (0x7fc0b3c85a83 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #18: at::Tensor::chunk(long, long) const + 0xe9 (0x7fc0b388c279 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #19: torch::cuda::scatter(at::Tensor const&, c10::ArrayRef<long>, c10::optional<std::vector<long, std::allocator<long> > > const&, long, c10::optional<std::vector<c10::optional<c10::cuda::CUDAStream>, std::allocator<c10::optional<c10::cuda::CUDAStream> > > > const&) + 0x29c (0x7fc0b62fc69c in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #20: <unknown function> + 0x793cbb (0x7fc0e45efcbb in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x206506 (0x7fc0e4062506 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #22: _PyMethodDef_RawFastCallKeywords + 0x254 (0x5641b859d744 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #23: _PyCFunction_FastCallKeywords + 0x21 (0x5641b859d861 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x4ecd (0x5641b86092bd in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #25: _PyEval_EvalCodeWithName + 0x2f9 (0x5641b854d539 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #26: _PyFunction_FastCallKeywords + 0x325 (0x5641b859cef5 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0x4b39 (0x5641b8608f29 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #28: _PyFunction_FastCallDict + 0x10b (0x5641b854e56b in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #29: THPFunction_apply(_object*, _object*) + 0x8d6 (0x7fc0e42fbe96 in /home/ali-admin/anaconda2/envs/AliEnv/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #30: _PyMethodDef_RawFastCallKeywords + 0x1f0 (0x5641b859d6e0 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #31: _PyCFunction_FastCallKeywords + 0x21 (0x5641b859d861 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #32: _PyEval_EvalFrameDefault + 0x4ecd (0x5641b86092bd in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #33: _PyEval_EvalCodeWithName + 0xac9 (0x5641b854dd09 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #34: _PyFunction_FastCallDict + 0x1d5 (0x5641b854e635 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #35: <unknown function> + 0x1526a2 (0x5641b858b6a2 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #36: PyIter_Next + 0xe (0x5641b8560c8e in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #37: PySequence_Tuple + 0xfb (0x5641b85af5cb in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #38: _PyEval_EvalFrameDefault + 0x5d84 (0x5641b860a174 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #39: _PyEval_EvalCodeWithName + 0xac9 (0x5641b854dd09 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #40: _PyFunction_FastCallKeywords + 0x387 (0x5641b859cf57 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #41: _PyEval_EvalFrameDefault + 0x416 (0x5641b8604806 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #42: _PyEval_EvalCodeWithName + 0xac9 (0x5641b854dd09 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #43: _PyFunction_FastCallKeywords + 0x387 (0x5641b859cf57 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #44: _PyEval_EvalFrameDefault + 0x416 (0x5641b8604806 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #45: _PyEval_EvalCodeWithName + 0x2f9 (0x5641b854d539 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #46: _PyFunction_FastCallKeywords + 0x387 (0x5641b859cf57 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #47: _PyEval_EvalFrameDefault + 0x14dc (0x5641b86058cc in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #48: _PyFunction_FastCallKeywords + 0xfb (0x5641b859cccb in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #49: _PyEval_EvalFrameDefault + 0x4b39 (0x5641b8608f29 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #50: _PyEval_EvalCodeWithName + 0x2f9 (0x5641b854d539 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #51: _PyFunction_FastCallDict + 0x1d5 (0x5641b854e635 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #52: _PyObject_Call_Prepend + 0x63 (0x5641b856ce53 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #53: PyObject_Call + 0x6e (0x5641b855fdbe in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #54: _PyEval_EvalFrameDefault + 0x1e42 (0x5641b8606232 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #55: _PyEval_EvalCodeWithName + 0x2f9 (0x5641b854d539 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #56: _PyFunction_FastCallDict + 0x1d5 (0x5641b854e635 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #57: _PyObject_Call_Prepend + 0x63 (0x5641b856ce53 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #58: <unknown function> + 0x16ba3a (0x5641b85a4a3a in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #59: _PyObject_FastCallKeywords + 0x49b (0x5641b85a58fb in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #60: _PyEval_EvalFrameDefault + 0x4a96 (0x5641b8608e86 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #61: _PyFunction_FastCallKeywords + 0xfb (0x5641b859cccb in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #62: _PyEval_EvalFrameDefault + 0x416 (0x5641b8604806 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)
frame #63: _PyEval_EvalCodeWithName + 0x2f9 (0x5641b854d539 in /home/ali-admin/anaconda2/envs/AliEnv/bin/python)

Which PyTorch version are you using?
I cannot reproduce this error using a dummy code snippet as:

model = GraphConvolution().to(device)
print(model.weight.device)
model = torch.nn.DataParallel(model)
out = model(torch.randn(8, 100).cuda(), torch.zeros(20, 1).cuda())

Note that device is not properly defined, as you should push the model to the default device (one device only).
nn.DataParallel will then split the input data in the batch dimension and push each chunk to the right device.

Could you post an executable code snippet to reproduce this error (using random tensors), please? :slight_smile:

PS: I don’t think it’s a good idea to tag certain people, as this might discourage others to post an answer. :wink:

Thanks for getting back to me and that PS, I’m new to this community and I thought we should mention someone. Anyways! Now I ask everyone to please help me to tackle with the issue that I have. :slightly_smiling_face:

by the way, it’s a graph network and it doesn’t use batch optimization. could it be the problem maybe?

pytorch version; 1.4.0
here is the full code:


import argparse
import numpy as np
import scipy.sparse as sp
import torch
from torch import optim
import torch.nn.functional as F
import torch.nn as nn
from gae.optimizer import loss_function
from scipy import sparse
from gae.layers import GraphConvolution

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='gcn_vae', help="models used")
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--hidden1', type=int, default=32, help='Number of units in hidden layer 1.')
parser.add_argument('--hidden2', type=int, default=4, help='Number of units in hidden layer 2.')
parser.add_argument('--lr', type=float, default=0.0005, help='Initial learning rate.')
parser.add_argument('--dropout', type=float, default=0., help='Dropout rate (1 - keep probability).')
parser.add_argument('--dataset-str', type=str, default='cora', help='type of dataset.')

args = parser.parse_args()

device = 'cuda'

def loss_function(preds, labels, mu, logvar, n_nodes, norm, pos_weight):
    cost = norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
    KLD = -0.5 / n_nodes * torch.mean(torch.sum(
        1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))
    return cost + KLD


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


class GCNModelVAE(nn.Module):
    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, dropout):
        super(GCNModelVAE, self).__init__()
        self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu)
        self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.gc3 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.dc = InnerProductDecoder(dropout, act=lambda x: x)

    def encode(self, x, adj):
        hidden1 = self.gc1(x, adj)
        return self.gc2(hidden1, adj), self.gc3(hidden1, adj)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, adj):
        mu, logvar = self.encode(x, adj)
        z = self.reparameterize(mu, logvar)
        return self.dc(z), mu, logvar


class InnerProductDecoder(nn.Module):
    """Decoder for using inner product for prediction."""

    def __init__(self, dropout, act=torch.sigmoid):
        super(InnerProductDecoder, self).__init__()
        self.dropout = dropout
        self.act = act

    def forward(self, z):
        z = F.dropout(z, self.dropout, training=self.training)
        adj = self.act(torch.mm(z, z.t()))
        return adj

def preprocess_graph(adj):
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    # return sparse_to_tuple(adj_normalized)
    return sparse_mx_to_torch_sparse_tensor(adj_normalized)

def gae_for(args):
    features = torch.randint(0, 2, size=(2708, 1433)).to(device)
    adj = np.asarray(np.random.randint(0, 2, size=(2708, 2708)))
    adj = sparse.csr_matrix(adj)

    n_nodes, feat_dim = features.shape

    # Store original adjacency matrix (without diagonal entries) for later
    adj_orig = adj
    adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
    adj_orig.eliminate_zeros()

    adj_train = adj

    # Some preprocessing
    adj_norm = preprocess_graph(adj)
    adj_label = adj_train + sp.eye(adj_train.shape[0])
    # adj_label = sparse_to_tuple(adj_label)
    adj_label = torch.FloatTensor(adj_label.toarray()).to(device)

    pos_weight = torch.tensor(float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()).float().to(device)
    norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)

    model = GCNModelVAE(feat_dim, args.hidden1, args.hidden2, args.dropout).to(device)
    model = nn.DataParallel(model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        print(epoch)
        model.train()
        optimizer.zero_grad()
        recovered, mu, logvar = model(features, adj_norm)
        loss = loss_function(preds=recovered, labels=adj_label,
                             mu=mu, logvar=logvar, n_nodes=n_nodes,
                             norm=norm, pos_weight=pos_weight)
        loss.backward()
        optimizer.step()
if __name__ == '__main__':
    gae_for(args)

And is GraphConvolution:

import torch
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class GraphConvolution(Module):

    def __init__(self, in_features, out_features, dropout=0., act=F.relu):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        input = F.dropout(input, self.dropout, self.training)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        output = self.act(output)
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

And I get this error;

raceback (most recent call last):
  File "/home/ali-admin/pytorch-examples/gae/gae/train.py", line 142, in <module>
    gae_for(args)
  File "/home/ali-admin/pytorch-examples/gae/gae/train.py", line 131, in gae_for
    recovered, mu, logvar = model(features, adj_norm)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 148, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 159, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 36, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 28, in scatter
    res = scatter_map(inputs)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 15, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 13, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 89, in forward
    outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  File "/home/ali-admin/.local/lib/python3.6/site-packages/torch/cuda/comm.py", line 147, in scatter
    return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
RuntimeError: sparse tensors do not have strides (strides at /pytorch/aten/src/ATen/SparseTensorImpl.cpp:52)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f7651ea8813 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: at::SparseTensorImpl::strides() const + 0xae (0x7f75d5ca1b8e in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #2: at::native::slice(at::Tensor const&, long, long, long, long) + 0xd8 (0x7f75d6037a88 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x1f5da39 (0x7f75d632ca39 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x1fa495b (0x7f75d637395b in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #5: at::native::narrow(at::Tensor const&, long, long, long) + 0x14a (0x7f75d604147a in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #6: <unknown function> + 0x1f5d889 (0x7f75d632c889 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #7: <unknown function> + 0x1e680f6 (0x7f75d62370f6 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x19cd8e7 (0x7f75d5d9c8e7 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #9: at::native::split(at::Tensor const&, long, long) + 0x1c3 (0x7f75d6041e33 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #10: <unknown function> + 0x1f5da49 (0x7f75d632ca49 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #11: <unknown function> + 0x1fa42b3 (0x7f75d63732b3 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #12: <unknown function> + 0x3aaa196 (0x7f75d7e79196 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #13: <unknown function> + 0x1fa42b3 (0x7f75d63732b3 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #14: at::native::chunk(at::Tensor const&, long, long) + 0x131 (0x7f75d603fe41 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #15: <unknown function> + 0x1f5d6a9 (0x7f75d632c6a9 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #16: <unknown function> + 0x3ae6cdb (0x7f75d7eb5cdb in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #17: <unknown function> + 0x1fa42b3 (0x7f75d63732b3 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #18: at::Tensor::chunk(long, long) const + 0xe9 (0x7f75d5f79aa9 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #19: torch::cuda::scatter(at::Tensor const&, c10::ArrayRef<long>, c10::optional<std::vector<long, std::allocator<long> > > const&, long, c10::optional<std::vector<c10::optional<c10::cuda::CUDAStream>, std::allocator<c10::optional<c10::cuda::CUDAStream> > > > const&) + 0x29c (0x7f75d89e9ecc in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #20: <unknown function> + 0x77fd8f (0x7f7652c82d8f in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x211014 (0x7f7652714014 in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #22: /usr/bin/python3.6() [0x50a635]
frame #23: _PyEval_EvalFrameDefault + 0x444 (0x50bfb4 in /usr/bin/python3.6)
frame #24: /usr/bin/python3.6() [0x507d64]
frame #25: /usr/bin/python3.6() [0x509a90]
frame #26: /usr/bin/python3.6() [0x50a48d]
frame #27: _PyEval_EvalFrameDefault + 0x444 (0x50bfb4 in /usr/bin/python3.6)
frame #28: /usr/bin/python3.6() [0x507d64]
frame #29: /usr/bin/python3.6() [0x588c8b]
frame #30: PyObject_Call + 0x3e (0x59fc4e in /usr/bin/python3.6)
frame #31: THPFunction_apply(_object*, _object*) + 0xa4f (0x7f76529a24af in /home/ali-admin/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #32: /usr/bin/python3.6() [0x50a2bf]
frame #33: _PyEval_EvalFrameDefault + 0x444 (0x50bfb4 in /usr/bin/python3.6)
frame #34: /usr/bin/python3.6() [0x507d64]
frame #35: _PyFunction_FastCallDict + 0x2e2 (0x509042 in /usr/bin/python3.6)
frame #36: _PyObject_FastCallDict + 0x4f1 (0x5a4b71 in /usr/bin/python3.6)
frame #37: /usr/bin/python3.6() [0x514ad6]
frame #38: PySequence_Tuple + 0x1e3 (0x5a5943 in /usr/bin/python3.6)
frame #39: _PyEval_EvalFrameDefault + 0x58d9 (0x511449 in /usr/bin/python3.6)
frame #40: /usr/bin/python3.6() [0x507d64]
frame #41: /usr/bin/python3.6() [0x509a90]
frame #42: /usr/bin/python3.6() [0x50a48d]
frame #43: _PyEval_EvalFrameDefault + 0x444 (0x50bfb4 in /usr/bin/python3.6)
frame #44: /usr/bin/python3.6() [0x507d64]
frame #45: /usr/bin/python3.6() [0x509a90]
frame #46: /usr/bin/python3.6() [0x50a48d]
frame #47: _PyEval_EvalFrameDefault + 0x444 (0x50bfb4 in /usr/bin/python3.6)
frame #48: /usr/bin/python3.6() [0x507d64]
frame #49: /usr/bin/python3.6() [0x509a90]
frame #50: /usr/bin/python3.6() [0x50a48d]
frame #51: _PyEval_EvalFrameDefault + 0x1226 (0x50cd96 in /usr/bin/python3.6)
frame #52: /usr/bin/python3.6() [0x509758]
frame #53: /usr/bin/python3.6() [0x50a48d]
frame #54: _PyEval_EvalFrameDefault + 0x444 (0x50bfb4 in /usr/bin/python3.6)
frame #55: /usr/bin/python3.6() [0x507d64]
frame #56: _PyFunction_FastCallDict + 0x2e2 (0x509042 in /usr/bin/python3.6)
frame #57: /usr/bin/python3.6() [0x594931]
frame #58: PyObject_Call + 0x3e (0x59fc4e in /usr/bin/python3.6)
frame #59: _PyEval_EvalFrameDefault + 0x17e6 (0x50d356 in /usr/bin/python3.6)
frame #60: /usr/bin/python3.6() [0x507d64]
frame #61: _PyFunction_FastCallDict + 0x2e2 (0x509042 in /usr/bin/python3.6)
frame #62: /usr/bin/python3.6() [0x594931]
frame #63: /usr/bin/python3.6() [0x54a941]

Thanks for posting the code!
I can reproduce the issue and stumbled upon this issue, which seems to be worked on.
Could you post a link to your use case there and ping for an update?

As a workaround you could call adj_norm.to_dense() for nn.DataParallel or try out DistributedDataParallel.