Hi @ptrblck, thanks for your prompt response. The truth is that I oversimplified the example without testing it. The above does work, but what I actually need is something like this:
```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        setattr(self, 'layer', nn.ModuleDict({'a': nn.Linear(8, 16), 'b': nn.Linear(16, 4)}))

    def forward(self, x):
        return getattr(self, 'layer')(x)
```
I just tried it, and it does raise a NotImplementedError. It seems that nn.ModuleDict causes the problem, since plain nn modules work fine.
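A minimal check of that diagnosis, as a sketch assuming a recent PyTorch version where nn.Module's default forward raises NotImplementedError. The input shape (2, 8) and chaining 'a' then 'b' are my own choices for illustration:

```python
import torch
import torch.nn as nn

# Same container as in Net: a dict of submodules with no forward() of its own.
layers = nn.ModuleDict({'a': nn.Linear(8, 16), 'b': nn.Linear(16, 4)})
x = torch.randn(2, 8)  # arbitrary test input: batch of 2, 8 features

# Calling the container itself falls through to nn.Module's default forward,
# which raises NotImplementedError.
try:
    layers(x)
    raised = False
except NotImplementedError:
    raised = True
print("ModuleDict call raised NotImplementedError:", raised)

# The individual entries are ordinary modules and can be called by key.
out = layers['b'](layers['a'](x))
print("chained output shape:", tuple(out.shape))
```

So the submodules themselves are fine; it is only calling the ModuleDict directly that fails.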
Any ideas on how to fix it? Many thanks!