DataParallel with Sparse Tensors

I’m trying to use a sparse matrix in my model. The sparse matrix is fixed and not being updated. I’m using a custom function for this. However, I’m getting the following error when I use DataParallel. Any ideas how to solve this?

RuntimeError: zeros is not implemented for type torch.cuda.sparse.

And here is the code. It works if you comment out the DataParallel line.

import torch

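class MyFunction(torch.autograd.Function):
  def forward(self, tensor, sparse_tensor):
    print(sparse_tensor.size(), tensor.size())
    # Save the sparse matrix so backward() can read it from saved_tensors.
    self.save_for_backward(sparse_tensor)
    # Computes tensor @ sparse_tensor^T, shape (batch, 2).
    return torch.matmul(sparse_tensor, tensor.t()).t()

  def backward(self, grad_output):
    sparse_tensor, = self.saved_tensors
    print(sparse_tensor.size(), grad_output.size())
    # Gradient w.r.t. tensor is grad_output @ sparse_tensor, shape (batch, 3).
    # The sparse map is fixed, so it gets no gradient.
    return torch.matmul(sparse_tensor.t(), grad_output.t()).t(), None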

class MyCustomNet(torch.nn.Module):
  def __init__(self):
    super(MyCustomNet, self).__init__()
    i = torch.LongTensor([[0, 1, 1],
                          [2, 0, 2]])
    v = torch.FloatTensor([3, 4, 5])
    self.sparse_map = torch.nn.Parameter(torch.sparse.FloatTensor(i, v, torch.Size([2,3])))
    print(torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense())
  def forward(self, x):
    self.func = MyFunction()
    return self.func(x, self.sparse_map)

net = MyCustomNet()
net = torch.nn.DataParallel(net)
data = torch.autograd.Variable(torch.zeros((10, 3))).cuda()
target = torch.autograd.Variable(torch.zeros((10, 2))).cuda()

out = net(data)
loss_fn = torch.nn.L1Loss()
loss = loss_fn(out, target)

torch.zeros isn’t implemented for sparse tensors yet, though adding it wouldn’t be hard. If you feel strongly about it, you could open a feature request about getting sparse tensors to work with DataParallel.
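In the meantime, one possible workaround is to keep the sparse matrix out of the module's parameters, so that DataParallel only has to replicate dense tensors. The sketch below is not the original code: it assumes a recent PyTorch release (torch.sparse_coo_tensor, torch.sparse.mm, register_buffer), and the class name SparseMapNet and the attribute names indices, values and map_size are made up for illustration. It stores the indices and values as ordinary buffers and rebuilds the sparse matrix inside forward. Whether this avoids the error on your version is something you would need to check.

```python
import torch


class SparseMapNet(torch.nn.Module):
    """Hypothetical variant of MyCustomNet: the fixed sparse map is stored as dense buffers."""

    def __init__(self):
        super().__init__()
        # Same 2x3 sparse matrix as in the question, kept as indices + values.
        # Buffers are dense tensors, so they are copied to each replica like any other state.
        self.register_buffer("indices", torch.tensor([[0, 1, 1], [2, 0, 2]]))
        self.register_buffer("values", torch.tensor([3.0, 4.0, 5.0]))
        self.map_size = (2, 3)

    def forward(self, x):
        # Rebuild the sparse matrix on whatever device this replica's buffers live on.
        sparse_map = torch.sparse_coo_tensor(self.indices, self.values, self.map_size)
        # (S @ x^T)^T has shape (batch, 2); autograd supplies the backward pass.
        return torch.sparse.mm(sparse_map, x.t()).t()


use_cuda = torch.cuda.is_available()
net = SparseMapNet()
if use_cuda:
    # Only wrap in DataParallel when a GPU is present; on CPU the plain module is used.
    net = torch.nn.DataParallel(net.cuda())

data = torch.ones(10, 3, device="cuda" if use_cuda else "cpu", requires_grad=True)
out = net(data)
out.sum().backward()
print(out.size(), data.grad.size())
```

Because the product is built from differentiable ops, no custom autograd Function is needed here. The trade-off is that the sparse tensor is reconstructed on every forward call, which is cheap for a small fixed map but may matter for a very large one.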

Out of curiosity, what are you using sparse tensors for?