You don’t understand the key points.
- Use integer indexing for `Sequential` and `ModuleList`
- Use a key for `ModuleDict`
- Use `.module_name` (attribute access) otherwise
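A minimal sketch of the three access styles on a toy module. The class `Demo` and its attribute names (`seq`, `mlist`, `mdict`, `out`) are hypothetical and exist only for illustration:

```python
import torch.nn as nn


# Hypothetical toy module mixing the three container/access styles
class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(2, 3), nn.ReLU())
        self.mlist = nn.ModuleList([nn.Linear(3, 4), nn.Linear(4, 5)])
        self.mdict = nn.ModuleDict({"head": nn.Linear(5, 1)})
        self.out = nn.Linear(1, 1)


demo = Demo()

first_linear = demo.seq[0]     # integer index into Sequential
second_linear = demo.mlist[1]  # integer index into ModuleList
head = demo.mdict["head"]      # string key into ModuleDict
out = demo.out                 # plain attribute access otherwise

# nn.Linear(in, out) stores its weight as (out, in)
print(first_linear.weight.shape)  # torch.Size([3, 2])
```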
For example, suppose I have a network like this:
```python
import torch.nn as nn


class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Placeholder: only the module structure matters for this example
        return x


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(1, 1, 1),
                SubNet(),
            ),
            nn.Conv2d(3, 3, 3)
        )
        self.conv = nn.Conv2d(4, 4, 4)

    def forward(self, x):
        return x


net = Net()
print(net)
```
The output is:
```
Net(
  (module): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
      (1): SubNet(
        (conv): Conv2d(2, 2, kernel_size=(2, 2), stride=(1, 1))
        (relu): ReLU(inplace=True)
      )
    )
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
  (conv): Conv2d(4, 4, kernel_size=(4, 4), stride=(1, 1))
)
```
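The numbers and names in that printout are the same segments PyTorch uses in parameter names, so listing `named_parameters()` is a quick way to read off the access path of every weight. A sketch that rebuilds the same `Net` so it runs on its own (the `forward` methods are left out, since they are not needed just to construct the modules):

```python
import torch.nn as nn


# Same structure as Net above, without forward(), which construction doesn't need
class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 2)
        self.relu = nn.ReLU(inplace=True)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 1, 1), SubNet()),
            nn.Conv2d(3, 3, 3),
        )
        self.conv = nn.Conv2d(4, 4, 4)


net = Net()

# Dotted names mirror the printed tree: digits are Sequential indices,
# words are attribute names.
names = [name for name, _ in net.named_parameters()]
for name in names:
    print(name)
```

A segment like `.0` or `.1` corresponds to `[0]` or `[1]` in Python, and a word such as `.conv` stays an attribute, so `module.0.1.conv.weight` reads as `net.module[0][1].conv.weight`.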
Then we have:
- `net.module[0][0].weight` → the weight of `Conv2d(1, 1, 1)`
- `net.module[0][1].conv.weight` → the weight of `Conv2d(2, 2, 2)`
- `net.module[1].weight` → the weight of `Conv2d(3, 3, 3)`
- `net.conv.weight` → the weight of `Conv2d(4, 4, 4)`
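To check that each expression points at the layer claimed, one option is to compare it with the tensor found under the matching dotted name and to look at the weight shapes: `Conv2d(c, c, k)` stores its weight as `(c, c, k, k)`, so each layer is easy to tell apart. A sketch, again rebuilding `Net` so it runs standalone:

```python
import torch.nn as nn


class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 2)
        self.relu = nn.ReLU(inplace=True)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.module = nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 1, 1), SubNet()),
            nn.Conv2d(3, 3, 3),
        )
        self.conv = nn.Conv2d(4, 4, 4)


net = Net()

# Index/attribute access, as in the list above
w1 = net.module[0][0].weight
w2 = net.module[0][1].conv.weight
w3 = net.module[1].weight
w4 = net.conv.weight

# The same objects looked up by their dotted names
params = dict(net.named_parameters())
modules = dict(net.named_modules())

print(w1.shape, w2.shape, w3.shape, w4.shape)
```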