The example code from the documentation raises an error when you add the following line:
model = MyModule()
class MyModule(nn.Module):
    """Example module that selects a layer and an activation by name per call.

    ``choices`` and ``activations`` are ``nn.ModuleDict`` containers, so the
    submodules they hold are registered on this module while still being
    addressable by string key.

    Note: ``nn.ModuleDict`` is not available in PyTorch releases older than
    0.4.1; there, constructing this class fails with
    ``AttributeError: module 'torch.nn' has no attribute 'ModuleDict'``.
    Upgrading PyTorch is the fix in that case.
    """

    def __init__(self):
        super(MyModule, self).__init__()
        # ModuleDict built from a mapping of name -> module.
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3),
        })
        # ModuleDict built from an iterable of [name, module] pairs.
        self.activations = nn.ModuleDict([
            ['lrelu', nn.LeakyReLU()],
            ['prelu', nn.PReLU()],
        ])

    def forward(self, x, choice, act):
        """Apply ``self.choices[choice]`` to ``x``, then ``self.activations[act]``.

        Args:
            x: input tensor; presumably (N, 10, H, W), since the 'conv'
                choice expects 10 input channels — TODO confirm.
            choice: key into ``self.choices`` ('conv' or 'pool').
            act: key into ``self.activations`` ('lrelu' or 'prelu').

        Returns:
            The output of the chosen layer passed through the chosen activation.
        """
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
Error:
AttributeError Traceback (most recent call last)
<ipython-input-4-a70ce091060c> in <module>()
----> 1 model = MyModule()
<ipython-input-3-20f44d4e0458> in __init__(self)
2 def __init__(self):
3 super(MyModule, self).__init__()
----> 4 self.choices = nn.ModuleDict({
5 'conv': nn.Conv2d(10, 10, 3),
6 'pool': nn.MaxPool2d(3)
AttributeError: module 'torch.nn' has no attribute 'ModuleDict'