Problem with DataParallel

I’m getting an assertion error for the code below when I wrap the model in DataParallel (it works without DataParallel). The problem seems to be related to self.lin; registering it explicitly with self.add_module didn’t help either.

import numpy as np
import torch

class MyCustomNet(torch.nn.Module):
  def __init__(self):
    super(MyCustomNet, self).__init__()
    self.lin = torch.nn.Linear(10, 1)
    # Non-trainable LongTensor parameter holding the column indices to select
    self.feature_to_select = torch.nn.Parameter(
        torch.LongTensor(np.random.randint(0, 10, (100,))),
        requires_grad=False)

  def forward(self, x):
    # Pick 100 columns of x according to feature_to_select
    return x.index_select(1, self.feature_to_select)

net = MyCustomNet()
net.eval()
net.cuda()
# Old-style (pre-0.4) inference input: a volatile Variable
data = torch.autograd.Variable(torch.FloatTensor(np.random.random((10, 10))), volatile=True).cuda()

net = torch.nn.DataParallel(net)
out = net(data)

Here is the error:

Traceback (most recent call last):
  File "reproduce_bug.py", line 52, in <module>
    out = net(data)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 60, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 70, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
    raise output
AssertionError

I’m having trouble reproducing your issue. With or without DataParallel, your code works for me. Are you using the latest version of PyTorch?

>>> import torch
>>> torch.__version__
'0.2.0_3'
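
If upgrading doesn’t fix it, one thing worth trying is storing the index tensor as a buffer instead of a non-trainable Parameter, so it stays out of the parameter broadcast that DataParallel performs when replicating the module. A minimal sketch of that variant (untested against your exact setup, and assuming the indices never need gradients):

import numpy as np
import torch

class MyCustomNet(torch.nn.Module):
  def __init__(self):
    super(MyCustomNet, self).__init__()
    self.lin = torch.nn.Linear(10, 1)
    # register_buffer keeps the LongTensor on the module (it moves with
    # .cuda() and is copied to each replica) without making it a Parameter
    self.register_buffer('feature_to_select',
                         torch.LongTensor(np.random.randint(0, 10, (100,))))

  def forward(self, x):
    return x.index_select(1, self.feature_to_select)

Buffers still move with net.cuda() and are replicated alongside the module, so forward should behave the same; the only difference is that feature_to_select no longer shows up in net.parameters().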