Hello everyone. I wondering if someone know of a work around for the following issue (or if it’s a bug). I have a function defined in a superclass that calls the forward of a custom module. When I call this function in the sub class in a data parallel setup, the custom module is on the wrong device. Is this supposed to happen?
Minimal Example
import torch
import torch.nn as nn
class SuperClass(nn.Module):
def __init__(self):
super(SuperClass, self).__init__()
self.fc2 = nn.Linear(100, 1)
self.mid = nn.Linear(100, 100)
self.f = self.my_layer_func
def my_layer_func(self, x):
return self.mid(x)
class SubClass(SuperClass):
def __init__(self):
super(SubClass, self).__init__()
def forward(self, x):
t = self.f(x)
d = self.fc2(t)
return d
print("Devices", torch.cuda.device_count())
mod = SubClass()
mod = nn.DataParallel(mod)
mod = mod.to(0)
input = torch.randn(2, 100, device=0)
out = mod(input)
Errors
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "<stdin>", line 5, in forward
File "<stdin>", line 8, in my_layer_func
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 0 does not equal 1 (while checking arguments for addmm)
If I instead change the my_layer_func
to take as input the mid
argument and pass that in during the forward of the subclass (i.e., self.f(x, self.mid)
), things work fine. I’d rather not have to do that if possible.
Any suggestions?