DataParallel on modules with dynamically overwritten forwards

Hi. When I use DataParallel on a module whose forward has been overwritten at the instance level, and the new forward calls the original one, it errors out with RuntimeError: arguments are located on different GPUs. Here’s a minimal example:

import torch
import torch.nn as nn
from types import MethodType

def forward_wrapper(self, x):
    return self._forward(x)

net = nn.Linear(10, 10)
net._forward = net.forward  # keep a handle to the original forward
net.forward = MethodType(forward_wrapper, net)  # overwrite it on the instance

net = nn.DataParallel(net).cuda()
x = torch.randn(8, 10).cuda()
net(x)  # RuntimeError: arguments are located on different GPUs

The following achieves the same thing and runs fine:

class ForwardWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()  # required before assigning submodules
        self.module = module
    def forward(self, x):
        return self.module.forward(x)  # call the wrapped module's original forward

net = nn.Linear(10, 10)
net = ForwardWrapper(net)

net = nn.DataParallel(net).cuda()
x = torch.randn(8, 10).cuda()
net(x)  # runs fine

So I’m not blocked by this. But I have no idea why one works and the other doesn’t, or what I would have to do if I perversely insisted on doing things the first way.
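
For what it’s worth, here is a sketch of doing the same wrapping at the class level instead of on the instance, which I’d expect to behave like the ForwardWrapper version. My guess is that DataParallel’s replicas pick up the instance-patched forward as a bound method that still points at the original module (and its parameters on GPU 0), whereas a method defined on the class is looked up against each replica’s own self. WrappedLinear is just an illustrative name:

import torch
import torch.nn as nn

class WrappedLinear(nn.Linear):
    def forward(self, x):
        # super().forward resolves against whichever replica self is,
        # so parameters and inputs stay on the same device
        return super().forward(x)

net = nn.DataParallel(WrappedLinear(10, 10)).cuda()
x = torch.randn(8, 10).cuda()
net(x)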

I’m experiencing the same problem. Did you manage to solve it?

Did any of you manage to solve it?