Hello,

I am trying to access the attributes of a model that has been wrapped by torch.nn.DataParallel. Following the tutorial, overriding __getattr__() should be enough, but the simple example below runs into infinite recursion, unless I am missing something:
```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.block1 = nn.Linear(10, 20)
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x


class MyDataParallel(torch.nn.DataParallel):
    """
    Allow nn.DataParallel to expose the model's own attributes. (supposedly)
    https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    """
    def __getattr__(self, name):
        print("Name: {}".format(name))
        return getattr(self.module, name)


device = torch.device("cpu")
model = Model()
model = MyDataParallel(model)
model.to(device)

x = torch.rand(20, 10)
x = x.to(device)
model(x)
```
Output:

```
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
Name: module
.... (more lines, deleted for clarity)
```
```
Traceback (most recent call last):
  File "test-data-parallel.py", line 37, in <module>
    model(x)
  File "/home/brian/Venvs/pytorch.1.0.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/brian/Venvs/pytorch.1.0.1/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 138, in forward
    return self.module(*inputs, **kwargs)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  File "test-data-parallel.py", line 26, in __getattr__
    return getattr(self.module, name)
  [Previous line repeated 491 more times]
  File "test-data-parallel.py", line 25, in __getattr__
    print("Name: {}".format(name))
RecursionError: maximum recursion depth exceeded while calling a Python object
```
The line getattr(self.module, name) calls MyDataParallel.__getattr__() recursively: since the override shadows nn.Module.__getattr__(), looking up self.module inside it triggers __getattr__("module") again, and the recursion never terminates.
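I suspect the recursion could be broken by delegating to nn.Module's own __getattr__() first (which resolves "module" from self._modules without re-entering the override) and only forwarding to the wrapped model when that fails, but I am not sure this is the intended approach. A minimal sketch of that idea:

```python
import torch
import torch.nn as nn


class MyDataParallel(torch.nn.DataParallel):
    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ looks the name up in _parameters,
            # _buffers and _modules, so "module" is resolved there
            # without re-entering this override
            return super().__getattr__(name)
        except AttributeError:
            # anything not found on the wrapper is forwarded to the
            # wrapped model
            return getattr(self.module, name)


model = MyDataParallel(nn.Linear(10, 20))
print(model.out_features)  # attribute of the wrapped Linear: 20
```

Forward passes still work as before, since self.module inside DataParallel.forward() now resolves through nn.Module's lookup instead of recursing.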
Any idea how to access the attributes of a model that was wrapped in torch.nn.DataParallel? The trivial workaround of changing every model.attribute call to model.module.attribute throughout the codebase is not an option.
Thank you!