I want to convert the torch.nn.Linear modules in my (possibly large) model to weight-dropped linear modules, and I want to train the model on multiple GPUs. However, my sample code raises a RuntimeError. First, I have _weight_drop(), which drops part of the weights of a torch.nn.Linear module (see the code below).
# This code is modified from torchnlp.nn.weight_drop
import torch
from torch.nn import Parameter


def _weight_drop(module, weights, dropout):
    # Re-register each listed weight as '<name>_raw' and re-create the dropped
    # weight on every forward pass.
    for name_w in weights:
        w = getattr(module, name_w)
        del module._parameters[name_w]
        module.register_parameter(name_w + '_raw', Parameter(w))

    original_module_forward = module.forward

    def forward(*args, **kwargs):
        # device = args[0].device
        for name_w in weights:
            raw_w = getattr(module, name_w + '_raw')
            # raw_w.to(device)
            w = torch.nn.functional.dropout(raw_w, p=dropout, training=module.training)
            # w.to(device)
            setattr(module, name_w, w)
        return original_module_forward(*args)

    def extra_repr(*args):
        bias = getattr(module, 'bias_raw', module.bias)
        return 'in_features={}, out_features={}, bias={}, drop_prob={}'.format(
            module.in_features, module.out_features, bias is not None, dropout)

    setattr(module, 'forward', forward)
    setattr(module, 'extra_repr', extra_repr)
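As noted at the end of this post, the wrapped layers behave as expected on a single device. For example, this minimal CPU sanity check (just for illustration) runs without problems:

# Single-device sanity check: wrap one nn.Linear and run a forward pass on CPU.
layer = torch.nn.Linear(5, 3)
_weight_drop(layer, ['weight'], dropout=0.5)
x = torch.randn(4, 5)
y = layer(x)           # dropout is applied to 'weight_raw' inside the patched forward()
print(y.shape)         # torch.Size([4, 3])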
I also follow the PyTorch tutorial code for torch.nn.DataParallel and construct a two-layer network (Model), then apply weight drop to its linear layers.
# This code is modified from the PyTorch tutorial for "DataParallel"
from torch.utils.data import Dataset, DataLoader

input_size = 5
hidden_size = 5
output_size = 2
batch_size = 30
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)


class Model(torch.nn.Module):
    # Our model
    def __init__(self, D_in, H, D_out):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, input):
        h_relu = self.linear1(input).clamp(min=0)
        output = self.linear2(h_relu)
        print("\tIn Model: input size", input.size(),
              "output size", output.size(), torch.cuda.current_device())
        return output


model = Model(input_size, hidden_size, output_size)
linear_module_list = [v for v in model.named_modules() if isinstance(v[1], torch.nn.Linear)]
for name, module in linear_module_list:
    _weight_drop(module, ['weight'], dropout=0.5)

model.to(device)
model = torch.nn.DataParallel(model)

data = list(rand_loader)[0]
while(True):
    input = data.to(device)
    output = model(input)
    print("Outside: input size", input.size(),
          "output_size", output.size())
However, output = model(input) fails with the following error message:
In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2]) 0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-4a83ee11bad2> in <module>
2 while(True):
3 input = data.to(device)
----> 4 output = model(input)
5 print("Outside: input size", input.size(),
6 "output_size", output.size())
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
150 return self.module(*inputs[0], **kwargs[0])
151 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 152 outputs = self.parallel_apply(replicas, inputs, kwargs)
153 return self.gather(outputs, self.output_device)
154
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
160
161 def parallel_apply(self, replicas, inputs, kwargs):
--> 162 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
163
164 def gather(self, outputs, output_device):
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
81 output = results[i]
82 if isinstance(output, Exception):
---> 83 raise output
84 outputs.append(output)
85 return outputs
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py in _worker(i, module, input, kwargs, device)
57 if not isinstance(input, (list, tuple)):
58 input = (input,)
---> 59 output = module(*input, **kwargs)
60 with lock:
61 results[i] = output
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
<ipython-input-2-ac863085502c> in forward(self, input)
33
34 def forward(self, input):
---> 35 h_relu = self.linear1(input).clamp(min=0)
36 output = self.linear2(h_relu)
37 print("\tIn Model: input size", input.size(),
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
<ipython-input-1-66fbe470597c> in forward(*args, **kwargs)
18 #w.to(device)
19 setattr(module, name_w, w)
---> 20 return original_module_forward(*args)
21
22 def extra_repr(*args):
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
90 @weak_script_method
91 def forward(self, input):
---> 92 return F.linear(input, self.weight, self.bias)
93
94 def extra_repr(self):
~/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1404 if input.dim() == 2 and bias is not None:
1405 # fused op is marginally faster
-> 1406 ret = torch.addmm(bias, input, weight.t())
1407 else:
1408 output = input.matmul(weight.t())
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/THC/generic/THCTensorMathBlas.cu:255
The main cause of this error is that the linear multiplication is computed between two tensors located on different GPUs. I tried to modify my _weight_drop() function to manually assign the current device inside the DataParallel worker, but it does not work. Does anyone have an idea how to solve this? The code works fine in single-GPU or CPU mode.
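For reference, the device-assignment attempt mentioned above looked roughly like the sketch below (essentially uncommenting the device lines inside forward(); the helper name _weight_drop_device is only for illustration). Under DataParallel it still fails with the same RuntimeError:

# Sketch of the attempted workaround: move the dropped weight onto the device
# of the incoming batch inside forward(). With DataParallel this still raised
# the same "arguments are located on different GPUs" error for me.
import torch
from torch.nn import Parameter


def _weight_drop_device(module, weights, dropout):
    for name_w in weights:
        w = getattr(module, name_w)
        del module._parameters[name_w]
        module.register_parameter(name_w + '_raw', Parameter(w))

    original_module_forward = module.forward

    def forward(*args, **kwargs):
        device = args[0].device  # device of this replica's input batch
        for name_w in weights:
            raw_w = getattr(module, name_w + '_raw')
            w = torch.nn.functional.dropout(raw_w, p=dropout, training=module.training)
            setattr(module, name_w, w.to(device))  # .to() is not in-place, so assign the result
        return original_module_forward(*args)

    setattr(module, 'forward', forward)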