Modules often have a structure like this (so that
state_dict() key names are readable):
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # Named attributes give readable state_dict() keys,
            # e.g. "module1.weight" instead of "0.weight".
            self.module1 = nn.Linear(1, 4)
            self.act1 = nn.ReLU()
            self.module2 = nn.Linear(4, 8)
            self.act2 = nn.ReLU()

        def forward(self, x):
            x = self.module1(x)
            x = self.act1(x)
            x = self.module2(x)
            x = self.act2(x)
            return x
Such a forward method could go away if we had some mix of nn.Sequential and nn.ModuleDict, i.e. a default forward method for ModuleDict that relies on the underlying OrderedDict and applies the modules one after another like a Sequential, while still keeping good module names.
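To make the idea concrete, here is a minimal sketch of what such a container could look like. The class name SequentialDict is hypothetical and not part of torch.nn; it only assumes that nn.ModuleDict iterates its values in insertion order when built from an OrderedDict.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class SequentialDict(nn.ModuleDict):
    """Hypothetical container (not a torch.nn class): a ModuleDict whose
    forward applies its modules in insertion order, like nn.Sequential."""

    def forward(self, x):
        for module in self.values():
            x = module(x)
        return x


# An OrderedDict built from (name, module) pairs keeps the order explicit.
layers = OrderedDict([
    ("module1", nn.Linear(1, 4)),
    ("act1", nn.ReLU()),
    ("module2", nn.Linear(4, 8)),
    ("act2", nn.ReLU()),
])

model = SequentialDict(layers)
out = model(torch.rand(3, 1))
keys = list(model.state_dict().keys())

# For comparison: nn.Sequential itself accepts an OrderedDict of named modules.
seq_keys = list(nn.Sequential(layers).state_dict().keys())
```

Both variants should produce the same readable state_dict() keys (module1.weight, module1.bias, and so on). The difference is that the ModuleDict-based variant also keeps dictionary-style access such as model["module1"].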
Currently, calling a ModuleDict errors out with an unclear error message (it should rather say “ModuleDict forward is not supported by default”):
    >>> torch.nn.ModuleDict(dict(a = torch.nn.Linear(2, 2), b = torch.nn.Linear(2, 4)))(torch.rand(3, 2))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/miniconda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
    TypeError: forward() takes 1 positional argument but 2 were given
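Since Python 3.7, dictionaries are guaranteed to preserve insertion order, and keyword-argument order has been preserved since Python 3.6 (https://www.python.org/dev/peps/pep-0468/), so the order in which the modules are passed could define the execution order.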
A default sequential forward for ModuleDict would also simplify static analysis and graph transformations of such models.
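The ordering guarantees mentioned above can be checked with plain Python. In this small sketch, collect_names is a hypothetical helper used only for illustration; it shows that both keyword arguments and dict literals keep the order in which the names were written.

```python
def collect_names(**kwargs):
    # kwargs is a regular dict; PEP 468 guarantees it keeps call-site order.
    return list(kwargs)


# Names are deliberately not in alphabetical order.
kwarg_names = collect_names(stem=1, act=2, head=3)

# Dict literals keep insertion order as well (guaranteed since Python 3.7).
dict_names = list({"stem": 1, "act": 2, "head": 3})
```

Note that this only covers the Python side. Whether a given ModuleDict version keeps the order of a plain dict is a separate question, so passing an OrderedDict is the conservative choice.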