DataParallel and nested functions in nn.Module

I am encountering a very strange bug. The code that produces it is fairly long, so below is a sketch of what is going on.

I have an nn.Module that describes a feedforward NN. This module has another nn.Module as a submodule that describes a custom layer. The layer module has a forward() function, some nn.Linear layers, and several helper methods.

class feedforward(nn.Module):
    def __init__(self, num_kernels, levels):
        super(feedforward, self).__init__()
        self.layer = custom_layer(num_kernels, levels)

    def forward(self, inpt):
        return self.layer(inpt)


class custom_layer(nn.Module):
    def __init__(self, num_kernels, levels):
        super(custom_layer, self).__init__()
        self.l1 = nn.Linear(...)

    def func_that_uses_l1(self, other):
        print('Devices #2', other.device, self.l1.weight.device)
        return self.l1(other)

    def somefunc(self, inpt):
        device = self.l1.weight.device
        other = torch.arange(10, device=device)
        print('Devices #1', other.device, self.l1.weight.device)
        something = self.func_that_uses_l1(other)
        return function of inpt and something

    def forward(self, inpt):
        return self.somefunc(inpt)

net = nn.DataParallel(feedforward(num_kernels, levels), device_ids=[0, 1, 2, 3]).to('cuda:0')
outpt = net(inpt)

The strange thing is that inside somefunc the devices match and are correctly distributed over the GPUs (self.l1.weight is on cuda:0, cuda:1, cuda:2, and so on), but inside func_that_uses_l1 self.l1 is always on cuda:0. The output looks like this:

Devices #1 cuda:1 cuda:1
Devices #2 cuda:1 cuda:0
Devices #1 cuda:2 cuda:2
Devices #2 cuda:2 cuda:0
Devices #1 cuda:3 cuda:3
Devices #2 cuda:3 cuda:0

This obviously raises a “RuntimeError: arguments are located on different GPUs”

How can something like this happen? Am I not allowed to nest method calls when using nn.DataParallel? It is so odd because I make sure that other is on the correct device (and that works properly), but when func_that_uses_l1 is called, self.l1 is magically back on cuda:0.
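
For context, my understanding is that nn.DataParallel roughly does the following on every forward call (a sketch using the public primitives from torch.nn.parallel):

# roughly what nn.DataParallel does on each forward call
replicas = nn.parallel.replicate(net.module, [0, 1, 2, 3])  # copy the module onto every GPU
inputs = nn.parallel.scatter(inpt, [0, 1, 2, 3])            # split the batch across the GPUs
outputs = nn.parallel.parallel_apply(replicas, inputs)      # run forward on each replica
outpt = nn.parallel.gather(outputs, 0)                      # collect the results on cuda:0

So self inside each replica's forward should be a complete copy living on that replica's device, which makes the cuda:0 weights in func_that_uses_l1 even stranger.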

Note: The code is a sketch and not a minimal working example.

Update: The error vanishes when I get rid of the nested method call, i.e. when I define func_that_uses_l1 as a local function inside somefunc rather than as a method of the class:

class custom_layer(nn.Module):
    def __init__(self, num_kernels, levels):
        super(custom_layer, self).__init__()
        self.l1 = nn.Linear(...)

    def somefunc(self, inpt):
        def func_that_uses_l1(other):
            print('Devices #2', other.device, self.l1.weight.device)
            return self.l1(other)

        device = self.l1.weight.device
        other = torch.arange(10, device=device)
        print('Devices #1', other.device, self.l1.weight.device)
        something = func_that_uses_l1(other)
        return function of inpt and something

    def forward(self, inpt):
        return self.somefunc(inpt)

But this is problematic for my code because I would like to reuse func_that_uses_l1. Is this behavior intended?

The issue is most likely raised because nn.DataParallel uses the forward method to send each chunk of the data to the appropriate device.
I assume that you are calling func_that_uses_l1 directly, which won't work.
You could try to pass a flag to forward and dispatch to the corresponding method there.
Alternatively, you would have to scatter and gather the data manually in func_that_uses_l1.
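
For example, the flag approach could look roughly like this (an untested sketch with made-up layer sizes; the point is that all work is dispatched through forward, so self is always bound to the replica on the current device):

class custom_layer(nn.Module):
    def __init__(self):
        super(custom_layer, self).__init__()
        self.l1 = nn.Linear(10, 10)  # hypothetical size

    def func_that_uses_l1(self, other):
        return self.l1(other)

    def somefunc(self, inpt):
        other = torch.arange(10., device=self.l1.weight.device)
        return inpt + self.func_that_uses_l1(other).sum()

    def forward(self, inpt, mode='full'):
        # DataParallel calls forward on every replica, so dispatching
        # here keeps self bound to the replica on the current device
        if mode == 'l1_only':
            return self.func_that_uses_l1(inpt)
        return self.somefunc(inpt)

net = nn.DataParallel(custom_layer(), device_ids=[0, 1, 2, 3]).to('cuda:0')
out_full = net(inpt)                # runs somefunc on each replica
out_l1 = net(inpt, mode='l1_only')  # runs func_that_uses_l1 instead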

I am not calling func_that_uses_l1 directly. It is being called by the forward method: forward calls somefunc, which calls func_that_uses_l1. If I remove the nested method call, the problem vanishes.

I am 95% certain this is a bug.

Could you post a reproducible code snippet that shows this bug, so that we can debug and fix it?
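
Sure. Here is a stripped-down, self-contained version of the sketch above (made-up layer sizes, assumes four GPUs) that should reproduce the structure described:

import torch
import torch.nn as nn

class custom_layer(nn.Module):
    def __init__(self):
        super(custom_layer, self).__init__()
        self.l1 = nn.Linear(10, 10)  # hypothetical size

    def func_that_uses_l1(self, other):
        print('Devices #2', other.device, self.l1.weight.device)
        return self.l1(other)

    def somefunc(self, inpt):
        device = self.l1.weight.device
        other = torch.arange(10., device=device)
        print('Devices #1', other.device, self.l1.weight.device)
        something = self.func_that_uses_l1(other)
        return inpt + something.sum()

    def forward(self, inpt):
        return self.somefunc(inpt)

class feedforward(nn.Module):
    def __init__(self):
        super(feedforward, self).__init__()
        self.layer = custom_layer()

    def forward(self, inpt):
        return self.layer(inpt)

net = nn.DataParallel(feedforward(), device_ids=[0, 1, 2, 3]).to('cuda:0')
inpt = torch.randn(8, 10, device='cuda:0')
outpt = net(inpt)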