# Save/Load weights of specific modules (nn.ModuleDict, etc)

Suppose that we have a simple network like the following:

```python
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict(
            {
                'conv_1_1': nn.Conv2d(3, 64, 3, stride=1, padding=1),
                'conv_1_1-relu': nn.ReLU(inplace=True),
                'conv_1_2': nn.Conv2d(64, 64, 3, stride=1, padding=1),
                'conv_1_2-relu': nn.ReLU(inplace=True),
                'conv_1_2-maxp': nn.MaxPool2d(kernel_size=2, stride=2),
            }
        )
        self.another_module_dict = nn.ModuleDict({...})  # contents elided
        self.conv = nn.Conv2d(32, 32, 3)  # kernel_size is required; 3 is an assumed value
```

The idea is that the network is structured as a set of `nn.ModuleDict`s or other "single" layers, like `nn.Conv2d`.
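Because each of these is assigned as an attribute of `Net`, it is registered as a child module. Its parameters therefore appear in `net.state_dict()` with the attribute name as a key prefix, while the submodule's own `state_dict()` uses keys relative to itself. A small sketch to inspect this, using a reduced, hypothetical version of `Net` (the contents of `another_module_dict` and the kernel size of `conv` are assumptions):

```python
import torch.nn as nn


# Reduced, hypothetical stand-in for `Net`; `another_module_dict` is elided
# in the original, so a single parameter-free ReLU is assumed here.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict({
            'conv_1_1': nn.Conv2d(3, 64, 3, stride=1, padding=1),
            'conv_1_1-relu': nn.ReLU(inplace=True),
        })
        self.another_module_dict = nn.ModuleDict({'relu': nn.ReLU()})
        self.conv = nn.Conv2d(32, 32, 3)  # kernel size 3 assumed


net = Net()

# Direct children, in the order they were assigned in __init__
print([name for name, _ in net.named_children()])

# Full state_dict: keys are prefixed with the attribute path (ReLUs have no parameters)
print(list(net.state_dict().keys()))

# A submodule's state_dict: keys are relative to that submodule
print(list(net.a_module_dict.state_dict().keys()))
```

Since each submodule's `state_dict()` is self-contained, it can be saved and later loaded back into the same attribute on its own.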

I define the network as follows:

```python
net = Net()
```

What I want to do is to save “submodels” that contain the weights for the given modules (`nn.ModuleDict`, `nn.Conv2d`, etc), and then at some point to load the aforementioned modules with the corresponding weights; something like this:

```python
net = Net()
# Do stuff with `net`

# TODO: save weights for `net.a_module_dict`, `net.another_module_dict`, and `net.conv`

# Do more stuff with `net`

# TODO: load weights and assign them appropriately to `net.a_module_dict`, `net.another_module_dict`, and `net.conv`
```
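A minimal sketch of one way to fill in the two TODOs, assuming all selected submodules go into a single checkpoint dict keyed by attribute name. The helpers `save_submodules` and `load_submodules`, the temporary file path, and the contents of `another_module_dict` are hypothetical; `state_dict`, `load_state_dict`, `torch.save`, and `torch.load` are standard PyTorch calls.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical, reduced `Net`; `another_module_dict` is elided in the
# original, so a single 1x1 conv is assumed for illustration.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict({
            'conv_1_1': nn.Conv2d(3, 64, 3, stride=1, padding=1),
            'conv_1_1-relu': nn.ReLU(inplace=True),
        })
        self.another_module_dict = nn.ModuleDict({'conv': nn.Conv2d(64, 32, 1)})
        self.conv = nn.Conv2d(32, 32, 3)  # kernel size 3 assumed


def save_submodules(net, names, path):
    """Save the state_dict of each named child module into one checkpoint file."""
    checkpoint = {name: getattr(net, name).state_dict() for name in names}
    torch.save(checkpoint, path)


def load_submodules(net, path):
    """Load each saved state_dict back into the child module of the same name."""
    checkpoint = torch.load(path)
    for name, state in checkpoint.items():
        getattr(net, name).load_state_dict(state)


net = Net()
names = ['a_module_dict', 'another_module_dict', 'conv']
path = os.path.join(tempfile.mkdtemp(), 'submodules.pth')  # hypothetical file name

save_submodules(net, names, path)
saved_weight = net.conv.weight.detach().clone()

# "Do more stuff": zero every parameter so the reload is observable
with torch.no_grad():
    for p in net.parameters():
        p.zero_()

load_submodules(net, path)
print(torch.equal(net.conv.weight, saved_weight))  # True
```

Keying the checkpoint by attribute name keeps everything in one file; calling `torch.save` once per submodule with separate file names works equally well.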

Thank you!

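You can access any submodule directly, as an attribute (or by key for entries of an `nn.ModuleDict`), and save the `state_dict` of as many of them as you want. For example: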

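```python
import torch

net = Net()
# Do stuff

# torch.save needs the object and a destination path (file names are illustrative)
torch.save(net.a_module_dict['conv_1_1'].state_dict(), 'conv_1_1.pth')
torch.save(net.a_module_dict.state_dict(), 'a_module_dict.pth')
torch.save(net.conv.state_dict(), 'conv.pth')

# Do whatever, then load a saved state_dict back into the same submodule
net.conv.load_state_dict(torch.load('conv.pth'))
```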