Modules often have a structure like this (so that state_dict()
key names are readable):
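import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Attribute names become state_dict() key prefixes, e.g. "module1.weight"
        self.module1 = nn.Linear(1, 4)
        self.act1 = nn.ReLU()
        self.module2 = nn.Linear(4, 8)
        self.act2 = nn.ReLU()

    def forward(self, x):
        # Hand-written forward that just applies each submodule in order
        x = self.module1(x)
        x = self.act1(x)
        x = self.module2(x)
        x = self.act2(x)
        return x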
Such a forward method could go away if we had a mix of nn.Sequential and nn.ModuleDict, i.e. a forward method on ModuleDict that relies on the ordered dict and runs the modules in sequence, as nn.Sequential does, while still keeping meaningful module names.
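A minimal sketch of the proposal, assuming a hypothetical subclass named SequentialModuleDict (not an existing PyTorch class): forward simply iterates over the ModuleDict's values in insertion order, so MyModule above becomes a plain container that keeps the same state_dict() keys.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class SequentialModuleDict(nn.ModuleDict):
    """Hypothetical ModuleDict that runs its modules in insertion order."""

    def forward(self, x):
        # ModuleDict keeps insertion order, so values() yields the modules
        # in the order they were added, like nn.Sequential.
        for module in self.values():
            x = module(x)
        return x


model = SequentialModuleDict(OrderedDict([
    ("module1", nn.Linear(1, 4)),
    ("act1", nn.ReLU()),
    ("module2", nn.Linear(4, 8)),
    ("act2", nn.ReLU()),
]))

out = model(torch.rand(3, 1))
print(out.shape)  # torch.Size([3, 8])
print(list(model.state_dict().keys()))
# ['module1.weight', 'module1.bias', 'module2.weight', 'module2.bias']
```

For comparison, nn.Sequential also accepts an OrderedDict of named modules and produces the same keys; a ModuleDict-based container would additionally keep dict-style access such as model["module1"].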
Currently, calling a ModuleDict fails with an unclear error message (it should rather say something like "ModuleDict forward is not supported by default"):
>>> torch.nn.ModuleDict(dict(a = torch.nn.Linear(2, 2), b = torch.nn.Linear(2, 4)))(torch.rand(3, 2))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/miniconda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given
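A sketch of the clearer failure mode, assuming a hypothetical subclass ExplicitModuleDict (not part of PyTorch) whose forward raises NotImplementedError with the suggested message instead of surfacing a TypeError about argument counts:

```python
import torch
import torch.nn as nn


class ExplicitModuleDict(nn.ModuleDict):
    """Hypothetical ModuleDict whose forward fails with a descriptive error."""

    def forward(self, *args, **kwargs):
        # Accept any arguments so the caller sees this message rather
        # than a signature mismatch.
        raise NotImplementedError(
            "ModuleDict forward is not supported by default; "
            "index the dict (e.g. d['a']) and call the module directly"
        )


d = ExplicitModuleDict(dict(a=nn.Linear(2, 2), b=nn.Linear(2, 4)))
try:
    d(torch.rand(3, 2))
except NotImplementedError as exc:
    print(exc)
```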
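Since Python 3.7, dictionaries are guaranteed to preserve insertion order, and keyword-argument order has been preserved since Python 3.6 (https://www.python.org/dev/peps/pep-0468/), so a dict built with dict(a=..., b=...) keeps the modules in the order they were written.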
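The two ordering guarantees can be checked in plain Python; kwarg_names is a hypothetical helper used only for illustration:

```python
from collections import OrderedDict


def kwarg_names(**kwargs):
    # PEP 468 (Python 3.6+): **kwargs preserves the caller's keyword order.
    return list(kwargs)


# Keyword order survives the call as written, not sorted.
print(kwarg_names(b=1, a=2, c=3))  # ['b', 'a', 'c']

# Python 3.7+: plain dicts keep insertion order, so dict(...) built from
# keyword arguments lists keys in the order written, like an OrderedDict.
d = dict(module2=None, module1=None)
print(list(d))  # ['module2', 'module1']
print(list(d) == list(OrderedDict(module2=None, module1=None)))  # True
```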
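This would also simplify static analysis and graph transformations of such models, since the execution order would be defined by the container itself rather than by a hand-written forward().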
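One way to illustrate the analysis benefit, assuming a PyTorch version that includes torch.fx and the same hypothetical SequentialModuleDict as a sketch: symbolic tracing unrolls the loop, and each submodule call shows up as a graph node named after its dict key.

```python
import torch
import torch.fx
import torch.nn as nn


class SequentialModuleDict(nn.ModuleDict):
    """Hypothetical ModuleDict that runs its modules in insertion order."""

    def forward(self, x):
        for module in self.values():
            x = module(x)
        return x


model = SequentialModuleDict(dict(module1=nn.Linear(1, 4), act1=nn.ReLU()))

# symbolic_trace unrolls the Python loop; each submodule call becomes a
# call_module node whose target is the dict key.
gm = torch.fx.symbolic_trace(model)
print([n.target for n in gm.graph.nodes if n.op == "call_module"])
# ['module1', 'act1']
```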