I am trying to use `getattr` in the `forward` method of an `nn.Module` class. A minimal example of what I’m trying to do is shown below:
```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 16)

    def forward(self, x):
        return getattr(self, 'layer')(x)
```
I get a `NotImplementedError` when I run `forward`:
```text
File "/usr/lib/python3.7/site-packages/torch/nn/modules/module.py", line 85, in forward
    raise NotImplementedError
NotImplementedError
```
Any ideas on how to solve it? Or, alternatively, another way to “create” the name of an attribute of `Net` inside `forward`? Thank you!
Your code works on my machine. Make sure you don’t have any typos in the definition of `forward`, and also check the indentation.
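Since `getattr` works fine inside `forward`, the attribute name can also be built at runtime. A minimal sketch, assuming hypothetical submodules named `layer_0` and `layer_1` and a hypothetical extra `idx` argument to `forward`:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical attribute names, registered as normal submodules
        self.layer_0 = nn.Linear(8, 16)
        self.layer_1 = nn.Linear(8, 4)

    def forward(self, x, idx):
        # Build the attribute name as a string and look it up with getattr;
        # nn.Module.__getattr__ finds it among the registered submodules
        layer = getattr(self, f"layer_{idx}")
        return layer(x)


net = Net()
x = torch.randn(2, 8)
out0 = net(x, 0)
out1 = net(x, 1)
print(out0.shape, out1.shape)  # torch.Size([2, 16]) torch.Size([2, 4])
```

Extra positional arguments passed to `net(...)` are forwarded to `forward`, so the selector can simply be part of the call.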
Hi @ptrblck, thanks for your prompt response. The truth is that I tried to make it too simple without testing it. The above indeed works, but what I truly need is something like this:
```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        setattr(self, 'layer', nn.ModuleDict({'a': nn.Linear(8, 16), 'b': nn.Linear(16, 4)}))

    def forward(self, x):
        return getattr(self, 'layer')(x)
```
I just tried it, and it indeed raises a `NotImplementedError`. It seems that using `nn.ModuleDict` causes the problem, since plain `nn` modules work fine.
Any ideas on how to fix it? Many thanks
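You are currently trying to call the `nn.ModuleDict` itself, which won’t work: it is only a container and does not implement `forward`, so the call falls back to the base `nn.Module.forward`, which raises the `NotImplementedError` seen in your traceback. You would have to address one of the layers instead: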
```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        setattr(self, 'layer', nn.ModuleDict({'a': nn.Linear(8, 16), 'b': nn.Linear(16, 4)}))

    def forward(self, x):
        # Look up each layer inside the ModuleDict by its key, then call it
        x = getattr(getattr(self, 'layer'), 'a')(x)
        x = getattr(getattr(self, 'layer'), 'b')(x)
        return x
```
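`nn.ModuleDict` also supports dict-style indexing, so the nested `getattr` calls can be replaced by `self.layer['a']`. A minimal sketch of the same model written that way, which also shows that calling the container directly fails:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.ModuleDict({'a': nn.Linear(8, 16), 'b': nn.Linear(16, 4)})

    def forward(self, x):
        # Index the container like a dict instead of nesting getattr calls
        x = self.layer['a'](x)
        x = self.layer['b'](x)
        return x


net = Net()
out = net(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4])

# Calling the ModuleDict itself fails, because it defines no forward
try:
    net.layer(torch.randn(2, 8))
except NotImplementedError:
    print("Calling the ModuleDict directly raised NotImplementedError")
```

If the layers are always applied in the same fixed order, `nn.Sequential` may be a simpler container, since it is callable and runs its submodules in sequence.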
What is your use case, if I may ask?
Hi @ptrblck, many thanks for your response. It should have been obvious, and I should have figured it out…
The reason for doing so is that I want my network to have a basic backbone sub-network and a number of branches (heads). For convenience, each head is implemented as an `nn.ModuleDict`, or even an `nn.ModuleList`. During `forward`, an `id` specifying which head should be used is given (as a string). If you think that’s not good practice, please let me know.
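One way to express this backbone-plus-heads layout is to store the heads in an `nn.ModuleDict` keyed by the string id and index it in `forward`. A minimal sketch, assuming hypothetical head names `'cls'` and `'reg'`, made-up layer sizes, and a single `nn.Linear` per head (a head could just as well be an `nn.Sequential`):

```python
import torch
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Shared backbone plus several heads selected by a string id (hypothetical names)."""

    def __init__(self, head_dims):
        super().__init__()
        # Shared feature extractor used by every head
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        # One head per id; ModuleDict registers each head's parameters with the parent
        self.heads = nn.ModuleDict(
            {name: nn.Linear(16, out_dim) for name, out_dim in head_dims.items()}
        )

    def forward(self, x, head_id):
        features = self.backbone(x)
        # Select the head by its string key at runtime
        return self.heads[head_id](features)


net = MultiHeadNet({'cls': 10, 'reg': 1})
x = torch.randn(4, 8)
logits = net(x, 'cls')
value = net(x, 'reg')
print(logits.shape, value.shape)  # torch.Size([4, 10]) torch.Size([4, 1])
```

Because every head is registered, `net.parameters()` covers all of them, while only the selected head takes part in a given forward pass. Heads that are not used in a step receive no gradient; with `DistributedDataParallel` this may require `find_unused_parameters=True`.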