Accessing attributes of a model wrapped in torch.nn.DataParallel: maximum recursion depth exceeded

Hello,
I am trying to access the attributes of a model that has been wrapped in torch.nn.DataParallel.
According to the tutorial, overriding __getattr__() should be enough. However, the following simple example runs into infinite recursion, unless I am missing something:

import torch
import torch.nn as nn


class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x


class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to call model's attributes. (supposedly)
    https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    """
    def __getattr__(self, name):
        print("Name: {}".format(name))
        return getattr(self.module, name)


device = torch.device("cpu")

model = Model()
model = MyDataParallel(model)
model.to(device)

x = torch.rand(20, 10)
x = x.to(device)
model(x)

Output:

Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
... (more lines omitted for clarity)
Traceback (most recent call last):
  File "test-data-parallel.py", line 37, in <module>
    model(x)
  File "/home/brian/Venvs/pytorch.1.0.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/brian/Venvs/pytorch.1.0.1/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 138, in forward
    return self.module(*inputs, **kwargs)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  [Previous line repeated 491 more times]
  File "test-data-parallel.py", line 25, in __getattr__
    print("Name: {}".format(name))
RecursionError: maximum recursion depth exceeded while calling a Python object

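The line getattr(self.module, name) recursively calls MyDataParallel.__getattr__(), because evaluating self.module itself goes through __getattr__. nn.Module stores submodules in self._modules rather than in the instance __dict__, so normal lookup of module fails and Python falls back to the overridden __getattr__('module'), which evaluates self.module again, and so on.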
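To see why the lookup never terminates, here is a minimal pure-Python sketch of the mechanism. It does not use torch: ModuleLike, Inner and NaiveWrapper are hypothetical stand-ins that assume only the one nn.Module behaviour relevant here, namely that child modules are kept in a separate _modules dict and found through __getattr__, not stored in the instance __dict__.

```python
class ModuleLike:
    """Hypothetical stand-in for nn.Module: child modules are stored in a
    separate _modules dict instead of the instance __dict__."""

    def __init__(self):
        # Create the registry directly so our __setattr__ is not involved.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ModuleLike):
            self._modules[name] = value  # registered like a submodule
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Python only calls this after normal attribute lookup has failed.
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Inner(ModuleLike):
    def __init__(self):
        super().__init__()
        self.hidden_size = 20  # plain attribute, lives in __dict__


class NaiveWrapper(ModuleLike):
    def __init__(self, module):
        super().__init__()
        self.module = module  # goes into _modules, not __dict__

    def __getattr__(self, name):
        # "module" is not in __dict__, so evaluating self.module calls
        # __getattr__("module") again, which evaluates self.module again...
        return getattr(self.module, name)


wrapped = NaiveWrapper(Inner())
print("module" in wrapped.__dict__)  # False

try:
    _ = wrapped.hidden_size
    outcome = "ok"
except RecursionError:
    outcome = "RecursionError"
print(outcome)  # RecursionError
```

Because RecursionError is not an AttributeError, it propagates straight out of the lookup instead of being treated as "attribute not found".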

Any idea how to allow access to the attributes of a model wrapped in torch.nn.DataParallel?

The trivial solution, replacing model.attribute with model.module.attribute throughout the code, is not an option here.

Thank you!


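This workaround breaks the recursion: super().__getattr__(name), i.e. nn.Module.__getattr__, finds module among the registered submodules, and only names it cannot resolve are forwarded to the wrapped model: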

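import torch


class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to access the wrapped model's attributes.
    """
    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ resolves registered submodules such as
            # self.module, so this call does not recurse.
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped model's attributes.
            return getattr(self.module, name)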
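The resolution order of this workaround can be checked without torch. The sketch below uses a hypothetical ModuleLike stand-in (an assumption that mimics only nn.Module's _modules lookup, not the real class) and applies the same try/except pattern in a hypothetical FixedWrapper.

```python
class ModuleLike:
    """Hypothetical stand-in for nn.Module (only the _modules lookup)."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ModuleLike):
            self._modules[name] = value  # registered submodule
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Inner(ModuleLike):
    def __init__(self):
        super().__init__()
        self.hidden_size = 20       # plain attribute
        self.block1 = ModuleLike()  # registered child module


class FixedWrapper(ModuleLike):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def __getattr__(self, name):
        try:
            # The base lookup resolves "module" from _modules without recursing.
            return super().__getattr__(name)
        except AttributeError:
            # Names the wrapper cannot resolve are forwarded to the wrapped model.
            return getattr(self.module, name)


inner = Inner()
wrapped = FixedWrapper(inner)
print(wrapped.hidden_size)             # 20
print(wrapped.block1 is inner.block1)  # True

missing_error = None
try:
    _ = wrapped.no_such_attribute
except AttributeError as err:
    missing_error = err
print("AttributeError:", missing_error)  # AttributeError: no_such_attribute
```

A genuinely missing name now ends in a normal AttributeError instead of a RecursionError. With the real nn.DataParallel, names the wrapper already defines (for example forward, device_ids or training) are found before __getattr__ runs, so they keep referring to the wrapper rather than to the wrapped model.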