Hi. When I use DataParallel on a module whose forward has been replaced at the instance level, and the replacement calls the original one, it fails with RuntimeError: arguments are located on different GPUs. Here's a minimal example:
import torch
from torch import nn
from types import MethodType

def forward_wrapper(self, x):
    return self._forward(x)

net = nn.Linear(10, 10)
net._forward = net.forward
net.forward = MethodType(forward_wrapper, net)
net = nn.DataParallel(net).cuda()
x = torch.randn(8, 10).cuda()
net(x)
The following achieves the same thing with a wrapper module instead, and runs fine:
class ForwardWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module.forward(x)

net = nn.Linear(10, 10)
net = ForwardWrapper(net)
net = nn.DataParallel(net).cuda()
x = torch.randn(8, 10).cuda()
net(x)
So I'm not blocked by this. But I have no idea why one works and the other doesn't, or what I would have to do if I perversely insisted on the first approach.
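For what it's worth, here is a pure-Python sketch of the one mechanism I can think of (this is only a guess, and assumes replication behaves roughly like a shallow copy; `Box`, `base_forward`, etc. are made-up names, not PyTorch internals): a MethodType bound to one instance stays bound to that instance even after the object is copied, so a replica's patched forward would still reach back into the original object.

```python
import copy
from types import MethodType

class Box:
    def __init__(self, name):
        self.name = name

    def base_forward(self):
        return self.name

b = Box("original")
# Patch at the instance level, the same way as in my first snippet:
b._forward = MethodType(Box.base_forward, b)           # bound to b
b.forward = MethodType(lambda self: self._forward(), b)  # also bound to b

r = copy.copy(b)     # shallow copy, standing in for a "replica"
r.name = "replica"

# r.forward is the very same bound-method object as b.forward, so the
# self inside it is still b, and the call returns "original", not "replica".
result = r.forward()
```

If DataParallel's replicas pick up instance attributes this way, that would explain why the replicas' inputs end up meeting the original module's GPU-0 parameters; but I haven't verified that this is what actually happens.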