Accessing attributes of a model wrapped in torch.nn.DataParallel: maximum recursion depth exceeded

Hello,
I am trying to access the attributes of a model that has been wrapped in torch.nn.DataParallel.
According to the tutorial, overriding __getattr__() should be enough. However, the following simple example ends in infinite recursion, unless I am missing something:

import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x


class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to call model's attributes. (supposedly)
    https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    """
    def __getattr__(self, name):
        print("Name: {}".format(name))
        return getattr(self.module, name)


device = torch.device("cpu")

model = Model()
model = MyDataParallel(model)
model.to(device)

x = torch.rand(20, 10)
x = x.to(device)
model(x)

Output:

Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
... (more lines deleted for clarity)
Traceback (most recent call last):
  File "test-data-parallel.py", line 37, in <module>
    model(x)
  File "/home/brian/Venvs/pytorch.1.0.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/brian/Venvs/pytorch.1.0.1/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 138, in forward
    return self.module(*inputs, **kwargs)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  [Previous line repeated 491 more times]
  File "test-data-parallel.py", line 25, in __getattr__
    print("Name: {}".format(name))
RecursionError: maximum recursion depth exceeded while calling a Python object

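The line getattr(self.module, name) calls MyDataParallel.__getattr__() recursively, because evaluating self.module itself goes through __getattr__: nn.Module stores submodules in self._modules rather than in the instance __dict__, so normal attribute lookup fails and Python falls back to __getattr__, which the override has replaced.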
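To see the mechanism without PyTorch, here is a minimal sketch using a hypothetical FakeModule class. It imitates, in simplified form (this is not PyTorch's actual code), how nn.Module keeps submodules in a _modules dict that is only consulted from __getattr__. BrokenWrapper and Inner are also hypothetical names for illustration.

```python
class FakeModule:
    """Hypothetical, simplified stand-in for nn.Module's attribute handling."""

    def __init__(self):
        # Bypass our own __setattr__ so _modules lands in the instance __dict__.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, FakeModule):
            self._modules[name] = value  # submodules are NOT put in __dict__
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Python calls this only after normal lookup (instance __dict__, class) fails.
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Inner(FakeModule):
    def __init__(self):
        super().__init__()
        self.hidden = 20  # plain attribute, stored in __dict__


class BrokenWrapper(FakeModule):
    def __init__(self, module):
        super().__init__()
        self.module = module  # goes into self._modules, not __dict__

    def __getattr__(self, name):
        # self.module is not in __dict__, so this line re-enters __getattr__.
        return getattr(self.module, name)


w = BrokenWrapper(Inner())
print("module" in w.__dict__, "module" in w._modules)  # False True

try:
    w.hidden
    outcome = "no error"
except RecursionError:
    outcome = "RecursionError"
print(outcome)  # RecursionError, the same failure as in the question
```

Overriding __getattr__ in the subclass shadows FakeModule.__getattr__ (in the real case, nn.Module.__getattr__), which was the only code that knew to look in _modules.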

Any idea how to allow access to the attributes of a model wrapped in torch.nn.DataParallel?

The trivial solution of replacing every model.attribute with model.module.attribute throughout the code is not an option here.

Thank you!

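This workaround breaks the recursion: it first tries nn.DataParallel's own __getattr__ (inherited from nn.Module), which knows how to find self.module, and only falls back to the wrapped model when that lookup raises AttributeError: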

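class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to access the wrapped model's attributes.
    """
    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ finds "module" and the wrapper's own
            # registered submodules, parameters and buffers.
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped model for everything else.
            return getattr(self.module, name)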
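As a usage sketch (assuming a reasonably recent PyTorch; the out_dim attribute is hypothetical and only added for illustration), the workaround forwards plain attributes and submodules of the wrapped model, while names that neither object has still raise an ordinary AttributeError instead of recursing:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    # Same toy model as in the question, plus a plain (non-module) attribute.
    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)
        self.out_dim = 20  # hypothetical attribute, for illustration only

    def forward(self, x):
        return self.block3(self.block2(self.block1(x)))


class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        try:
            # Resolves "module" and the wrapper's own registered members.
            return super().__getattr__(name)
        except AttributeError:
            # Anything the wrapper does not have is looked up on the model.
            return getattr(self.module, name)


# With GPUs, DataParallel expects the parameters on device_ids[0]
# (assumed here to be the default CUDA device); without GPUs it runs on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyDataParallel(Model().to(device))

out = model(torch.rand(20, 10, device=device))
print(out.shape)      # torch.Size([20, 20])
print(model.out_dim)  # 20, forwarded to the wrapped Model
print(model.block1 is model.module.block1)  # True
print(hasattr(model, "does_not_exist"))     # False: AttributeError, no recursion
```

Because the wrapper's own lookup is tried first, a name that exists on both objects (for example module or device_ids) resolves to the DataParallel wrapper, not to the model.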