DataParallel gives "arguments are located on different GPUs" when you assign to `forward()`

I ran into some interesting (and breaking) behaviour with DataParallel when I override forward() by assignment. Something like this:

import torch.nn as nn

class MyEmbedding(nn.Embedding):
    def __init__(self, *args, swap_fwd=True, **kwargs):
        super().__init__(*args, **kwargs)
        if swap_fwd:
            self.FC = nn.Linear(self.embedding_dim, self.embedding_dim)
            # Override forward() by assigning a bound method to the instance.
            self.forward = self.forward_
        else:
            self.FC = None

    def forward_(self, inputs):
        embeds = super().forward(inputs)
        return self.FC(embeds)

This works fine on a single GPU, but with multiple GPUs I get “arguments are located on different GPUs”. It seems that when forward_() is executing, self is not the appropriate replica (on the correct GPU), but the original module (typically on GPU 0).
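
For reference, I drive it in the usual way, roughly like this (the sizes and batch shape below are made up for illustration):

import torch
import torch.nn as nn

model = MyEmbedding(1000, 64, swap_fwd=True)     # num_embeddings=1000, embedding_dim=64
model = nn.DataParallel(model).cuda()            # replicate across all visible GPUs

inputs = torch.randint(0, 1000, (8, 16)).cuda()  # (batch, seq_len) index tensor, ends up on GPU 0
out = model(inputs)  # fine on 1 GPU; "arguments are located on different GPUs" on >1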

If instead I override forward() in the normal way, everything works fine. Something like:

...
    def forward(self, inputs):
        embeds = super().forward(inputs)
        if self.FC is not None:
            return self.FC(embeds)
        return embeds

I’d be interested to know what’s going on here.

I assume that you are calling model.forward_() in your code instead of calling the model directly?
If so, this breakage is expected, as utilities such as hooks and data parallel rely on the “standard” usage of calling the model directly.
For nn.DataParallel, the model.__call__ method would call into DataParallel.forward, which takes care of the parallelization, as seen here.
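
Roughly, the forward pass of nn.DataParallel does something like the following, using the public helpers in torch.nn.parallel. This is a simplified sketch; the real implementation also handles kwargs, device checks, and the single-GPU fast path:

# Simplified sketch of what a call to a DataParallel-wrapped model does.
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(module, inputs, device_ids, output_device):
    scattered = scatter(inputs, device_ids)        # split the batch across the GPUs
    replicas = replicate(module, device_ids)       # one copy of the module per GPU
    outputs = parallel_apply(replicas, scattered)  # call each replica on its shard
    return gather(outputs, output_device)          # collect the results on one device

The replicate() step is the relevant one here: each replica gets its own copies of the parameters on its GPU, and calling the wrapped model directly is what routes you through this machinery.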

No, I use the usual __call__ interface (i.e. I treat the module instance as a function). It does end up calling my forward_() method, but the self in that call is the original module instance, not the replica created by DataParallel. That means its parameters are not on the same GPU as the inputs passed to forward.

I suspect it’s related to the way Python binds methods to instances: when I do self.forward = self.forward_, the current self (i.e. the original instance that executes __init__) is baked into the bound method stored in forward, and that association isn’t updated when DataParallel clones the module. I looked at it with a debugger, and id(self) within the forward_ method is different from id(self) in Module.__call__() immediately beforehand (at least for the threads working with a different GPU).
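
This can be demonstrated without PyTorch at all. As far as I can tell, the replication step builds each per-GPU replica by copying the module’s __dict__ rather than re-running __init__, so a bound method stored as an instance attribute keeps pointing at the original object. A minimal plain-Python sketch of that effect (illustrative only, not PyTorch’s actual replication code):

# A bound method stored in __dict__ keeps its original `self` when an
# object is "replicated" by copying __dict__ instead of re-initialising.
class Widget:
    def __init__(self):
        self.forward = self.forward_   # stash a bound method on the instance

    def forward_(self):
        return id(self)                # which object is `self` in here?

original = Widget()

# Imitate replica creation: fresh object plus a shallow copy of __dict__.
replica = Widget.__new__(Widget)
replica.__dict__ = original.__dict__.copy()

print(replica.forward() == id(original))   # True  -> still bound to the original
print(replica.forward() == id(replica))    # False -> the replica never sees itself

With the normal override, forward is looked up on the class and bound to whichever instance it is called on, so each replica’s forward sees the replica itself, with its parameters on the right GPU.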

I’m posting this partly so others can find it if they run into the same behaviour, and partly because I’d like a better understanding of the mechanism that causes it.