Custom module superclass function fails to move module to other device with DataParallel

Hello everyone. I'm wondering if someone knows of a workaround for the following issue (or if it's a bug). I have a function defined in a superclass that calls the forward of a custom module. When I call this function from the subclass in a DataParallel setup, the custom module ends up on the wrong device. Is this supposed to happen?

Minimal Example

import torch
import torch.nn as nn


class SuperClass(nn.Module):
    def __init__(self):
        super(SuperClass, self).__init__()
        self.fc2 = nn.Linear(100, 1)
        self.mid = nn.Linear(100, 100)
        # store a reference to the (bound) method so subclasses can call self.f
        self.f = self.my_layer_func

    def my_layer_func(self, x):
        return self.mid(x)


class SubClass(SuperClass):
    def __init__(self):
        super(SubClass, self).__init__()

    def forward(self, x):
        t = self.f(x)   # calls my_layer_func through the stored reference
        d = self.fc2(t)
        return d


print("Devices", torch.cuda.device_count())
mod = SubClass()
mod = nn.DataParallel(mod)
mod = mod.to(0)

input = torch.randn(2, 100, device=0)
out = mod(input)

Errors

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<stdin>", line 5, in forward
  File "<stdin>", line 8, in my_layer_func
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/dfs/scratch0/lorr1/py3.8/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 0 does not equal 1 (while checking arguments for addmm)

If I instead change my_layer_func to take the mid module as an argument and pass it in from the subclass's forward (i.e., self.f(x, self.mid)), things work fine. I'd rather not have to do that if possible.
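
For concreteness, this is roughly what that workaround looks like (just a sketch; the layer argument name is only for illustration):

class SuperClass(nn.Module):
    def __init__(self):
        super(SuperClass, self).__init__()
        self.fc2 = nn.Linear(100, 1)
        self.mid = nn.Linear(100, 100)
        self.f = self.my_layer_func

    def my_layer_func(self, x, layer):
        # use the layer passed in by the caller (the replica's own self.mid)
        return layer(x)


class SubClass(SuperClass):
    def forward(self, x):
        t = self.f(x, self.mid)  # pass the submodule explicitly
        d = self.fc2(t)
        return d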

Any suggestions?

Question: Can’t you directly call self.my_layer_func(x)?
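
i.e., something like the following sketch, so the method is looked up on each replica's self at call time rather than through the bound method stored in __init__ (which still refers to the original module):

class SubClass(SuperClass):
    def __init__(self):
        super(SubClass, self).__init__()

    def forward(self, x):
        t = self.my_layer_func(x)  # direct call; self.mid resolves on this replica
        d = self.fc2(t)
        return d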