RuntimeError: arguments are located on different GPUs for pytorch 1.1.0 with 2 GPUs

I want to convert the torch.nn.Linear modules in my (possibly large) model to weight-drop linear modules, and I want to train the model on multiple GPUs. However, my sample code raises a RuntimeError. First, I have _weight_drop(), which applies dropout to some of the weights of a torch.nn.Linear module (see the code below).

# This code is modified from torchnlp.nn.weight_drop
import torch
from torch.nn import Parameter

def _weight_drop(module, weights, dropout):
    for name_w in weights:
        w = getattr(module, name_w)
        del module._parameters[name_w]
        module.register_parameter(name_w + '_raw', Parameter(w))
    original_module_forward = module.forward

    def forward(*args, **kwargs):
        #device = args[0].device
        for name_w in weights:
            raw_w = getattr(module, name_w + '_raw')
            w = torch.nn.functional.dropout(raw_w, p=dropout, training=module.training)
            setattr(module, name_w, w)
        return original_module_forward(*args)    

    def extra_repr(*args):
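        # Fall back to the plain bias when no 'bias_raw' parameter was registered.
        bias = getattr(module, 'bias_raw', None)
        if bias is None:
            bias = getattr(module, 'bias', None)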
        return 'in_features={}, out_features={}, bias={}, drop_prob={}'.format(module.in_features, module.out_features, bias is not None, dropout)

    setattr(module, 'forward', forward)
    setattr(module, 'extra_repr', extra_repr)

I also refer to the tutorial code for torch.nn.DataParallel and construct the two-layer network (Model) with weight drop.

# This code is modified from the PyTorch tutorial for "DataParallel"
from torch.utils.data import Dataset, DataLoader

input_size = 5
hidden_size = 5
output_size = 2
batch_size = 30
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)

class Model(torch.nn.Module):
    # Our model
    def __init__(self, D_in, H, D_out):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, input):
        h_relu = self.linear1(input).clamp(min=0)
        output = self.linear2(h_relu)
        print("\tIn Model: input size", input.size(),
              "output size", output.size(), torch.cuda.current_device())
        return output

model = Model(input_size, hidden_size, output_size)
linear_module_list = [v for v in model.named_modules() if isinstance(v[1], torch.nn.Linear)]
for name, module in linear_module_list:
    _weight_drop(module, ['weight'], dropout=0.5)
model = torch.nn.DataParallel(model)

data = list(rand_loader)[0]
while(True):
    input = data.to(device)
    output = model(input)
    print("Outside: input size", input.size(),
          "output_size", output.size())

However, “output = model(input)” fails in this code with the following error message:

In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 0
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-4a83ee11bad2> in <module>
      2 while(True):
      3     input = data.to(device)
----> 4     output = model(input)
      5     print("Outside: input size", input.size(),
      6           "output_size", output.size())

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/ in forward(self, *inputs, **kwargs)
    150             return self.module(*inputs[0], **kwargs[0])
    151         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 152         outputs = self.parallel_apply(replicas, inputs, kwargs)
    153         return self.gather(outputs, self.output_device)

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/ in parallel_apply(self, replicas, inputs, kwargs)
    161     def parallel_apply(self, replicas, inputs, kwargs):
--> 162         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    164     def gather(self, outputs, output_device):

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/ in parallel_apply(modules, inputs, kwargs_tup, devices)
     81         output = results[i]
     82         if isinstance(output, Exception):
---> 83             raise output
     84         outputs.append(output)
     85     return outputs

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/ in _worker(i, module, input, kwargs, device)
     57                 if not isinstance(input, (list, tuple)):
     58                     input = (input,)
---> 59                 output = module(*input, **kwargs)
     60             with lock:
     61                 results[i] = output

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-2-ac863085502c> in forward(self, input)
     34     def forward(self, input):
---> 35         h_relu = self.linear1(input).clamp(min=0)
     36         output = self.linear2(h_relu)
     37         print("\tIn Model: input size", input.size(),

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-1-66fbe470597c> in forward(*args, **kwargs)
     19             setattr(module, name_w, w)
---> 20         return original_module_forward(*args)
     22     def extra_repr(*args):

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/ in forward(self, input)
     90     @weak_script_method
     91     def forward(self, input):
---> 92         return F.linear(input, self.weight, self.bias)
     94     def extra_repr(self):

~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/ in linear(input, weight, bias)
   1404     if input.dim() == 2 and bias is not None:
   1405         # fused op is marginally faster
-> 1406         ret = torch.addmm(bias, input, weight.t())
   1407     else:
   1408         output = input.matmul(weight.t())

RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/generic/

The main reason for this error is that the code tries to compute a linear multiplication between two tensors that live on different GPUs. I tried modifying my _weight_drop() function to manually assign the current device during the DataParallel process, but that does not work. Does anyone have an idea how to solve this problem? The code works fine in single-GPU or CPU mode.

  1. I don’t think you’re supposed to use _weight_drop directly; use the WeightDrop class from the same source instead. The super().__init__(...) call there should do some magic. Since you’ve modified the function, perhaps also write a custom wrapper that does the same thing as WeightDrop.
  2. Do not manually assign GPUs to components that are already wrapped inside DataParallel unless you know what you’re doing.
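
One possible shape for the custom wrapper mentioned in point 1 is a subclass of torch.nn.Linear that applies dropout to its weight inside forward. This is only a minimal sketch: WeightDropLinear and its weight_dropout argument are hypothetical names, it is not torchnlp's implementation, and it is exercised here on CPU only, not under DataParallel. The idea is that forward reads self.weight from whichever copy of the layer is being called, rather than from an object captured when the model was built.

```python
import torch
import torch.nn.functional as F


class WeightDropLinear(torch.nn.Linear):
    """Hypothetical Linear layer that applies dropout to its weight on every forward."""

    def __init__(self, in_features, out_features, bias=True, weight_dropout=0.0):
        super().__init__(in_features, out_features, bias=bias)
        self.weight_dropout = weight_dropout

    def forward(self, input):
        # `self` is whichever copy of the layer is being called, so the dropped
        # weight is derived from that copy's own parameter.
        w = F.dropout(self.weight, p=self.weight_dropout, training=self.training)
        return F.linear(input, w, self.bias)


layer = WeightDropLinear(5, 2, weight_dropout=0.5)
x = torch.randn(3, 5)

layer.eval()   # dropout is the identity in eval mode
eval_out = layer(x)
layer.train()  # a fresh random mask is applied to the weight on each call
train_out = layer(x)
```

In the Model above, self.linear1 and self.linear2 would be created as WeightDropLinear instances instead of being patched after construction.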

Thank you for your comment. I also tried WeightDrop from the torchnlp package, but it returned the same error. Since both depend on _weight_drop(), I believe something goes wrong in that function when I wrap my model in torch.nn.DataParallel.
I understand that DataParallel splits the input into inputs[] (a list of chunks of the input), with inputs[i] on the i-th GPU. DataParallel also replicates the wrapped model into replicas[] (a list of model copies), with replicas[i] on the i-th GPU. In my case, replicas[1] raises the same RuntimeError. The reason is that replicas[1].weight and replicas[1].bias are not properly placed on GPU 1. Do you have any idea about this? Manually assigning these tensors to the target GPU is unstable, as you pointed out (I tried, but it still returns the same error for some inputs).
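
One possible explanation for why replicas[1] ends up using tensors from GPU 0 is that the forward installed by _weight_drop() is a closure over the original module object. The sketch below uses plain Python with no GPUs. It assumes, and this is an assumption about DataParallel internals that is not confirmed in this thread, that a replica is built by shallow-copying the module's attribute dictionary. FakeModule, patch_forward and shallow_replica are hypothetical names used only for illustration.

```python
class FakeModule:
    """Hypothetical stand-in for a module; `device` mimics where its weight lives."""

    def __init__(self, device):
        self.device = device

    def forward(self, x):
        return "input {} * weight on {}".format(x, self.device)


def patch_forward(module):
    """Mimics _weight_drop(): store a closure over `module` as an instance attribute."""
    original_forward = module.forward  # bound to this particular object

    def forward(*args):
        return original_forward(*args)

    module.forward = forward


def shallow_replica(module, device):
    """Assumed replication scheme: new object, shallow copy of the attribute dict."""
    replica = module.__class__.__new__(module.__class__)
    replica.__dict__ = module.__dict__.copy()
    replica.device = device
    return replica


original = FakeModule("cuda:0")
patch_forward(original)
replica = shallow_replica(original, "cuda:1")

# The replica reports cuda:1, but its copied forward still reaches the original.
result = replica.forward("on cuda:1")
print(replica.device)
print(result)
```

If the assumption holds, the patched forward on every replica would keep calling setattr() and the original forward on the module that lives on GPU 0, which would match the device mismatch in the traceback.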