Save/Load weights of specific modules (nn.ModuleDict, etc.)

Suppose that we have a simple network like the following:

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict(
            {
                'conv_1_1': nn.Conv2d(3, 64, 3, stride=1, padding=1),
                'conv_1_1-relu': nn.ReLU(inplace=True),
                'conv_1_2': nn.Conv2d(64, 64, 3, stride=1, padding=1),
                'conv_1_2-relu': nn.ReLU(inplace=True),
                'conv_1_2-maxp': nn.MaxPool2d(kernel_size=2, stride=2),
            }
        )
        self.another_module_dict = nn.ModuleDict({...})  # contents elided
        # nn.Conv2d requires a kernel_size; 3 is assumed here
        self.conv = nn.Conv2d(32, 32, 3)

The idea is that the network is structured as a set of nn.ModuleDicts, or other "single" layers such as nn.Conv2d.
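Because each of these pieces is registered as an attribute, it is a child module of its own, and its parameters appear in `net.state_dict()` under keys prefixed with the attribute name. A minimal sketch to see this, assuming a hypothetical reduced `SmallNet` (the elided `another_module_dict` is left out, and `kernel_size=3` for `conv` is an assumption):

```python
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical reduced version of Net: one ModuleDict plus one plain layer."""

    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict(
            {
                'conv_1_1': nn.Conv2d(3, 64, 3, stride=1, padding=1),
                'conv_1_1-relu': nn.ReLU(inplace=True),
            }
        )
        # kernel_size=3 is an assumption; nn.Conv2d requires a kernel size.
        self.conv = nn.Conv2d(32, 32, 3)


net = SmallNet()

# Top-level sub-modules, in registration order.
print([name for name, _ in net.named_children()])

# Keys are prefixed by the attribute path; ReLU has no parameters, so no keys.
print(list(net.state_dict().keys()))
```

Each child's own `state_dict()` contains the same entries with the prefix stripped, which is what makes saving sub-modules individually straightforward.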

I instantiate the network as follows:

net = Net()

What I want to do is save "submodels" that contain the weights of given modules (nn.ModuleDict, nn.Conv2d, etc.), and later load those weights back into the corresponding modules; something like this:

net = Net()
# Do stuff with `net`

# TODO: save weights for `net.a_module_dict`, `net.another_module_dict`, and `net.conv`

# Do more stuff with `net`

# TODO: load weights and assign them appropriately to `net.a_module_dict`, `net.another_module_dict`, and `net.conv`
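One way to fill in the two TODOs, sketched under assumptions: the helper names `save_submodules` and `load_submodules` are hypothetical, `SmallNet` stands in for `Net`, and everything is bundled into one file as a dict mapping each sub-module name to its `state_dict()`. `nn.Module.get_submodule` (PyTorch 1.9+) also accepts dotted paths such as `'a_module_dict.conv_1_1'`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical stand-in for Net (another_module_dict omitted)."""

    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict(
            {'conv_1_1': nn.Conv2d(3, 64, 3, padding=1), 'relu': nn.ReLU(inplace=True)}
        )
        self.conv = nn.Conv2d(32, 32, 3)  # kernel_size=3 is an assumption


def save_submodules(net, names, path):
    """Hypothetical helper: save each named sub-module's state_dict into one file."""
    bundle = {name: net.get_submodule(name).state_dict() for name in names}
    torch.save(bundle, path)


def load_submodules(net, path):
    """Hypothetical helper: load each saved state_dict into the sub-module of that name."""
    bundle = torch.load(path)
    for name, state in bundle.items():
        net.get_submodule(name).load_state_dict(state)


net = SmallNet()
# Do stuff with `net` ...
path = os.path.join(tempfile.mkdtemp(), 'submodules.pth')
save_submodules(net, ['a_module_dict', 'conv'], path)
saved_weight = net.conv.weight.detach().clone()

# Do more stuff with `net`: here, zero the conv weights to simulate training.
with torch.no_grad():
    net.conv.weight.zero_()

load_submodules(net, path)
print(torch.equal(net.conv.weight, saved_weight))  # True
```

Keeping everything in a single dict means one file per checkpoint; saving one file per sub-module works just as well if you need to mix and match them later.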

Thank you!

You can access any sub-module directly through its attribute (or through its key, for entries of an nn.ModuleDict) and save as many as you want. For example:

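import torch

net = Net()
# Do stuff
torch.save(net.a_module_dict.state_dict(), 'a_module_dict.pth')
torch.save(net.a_module_dict['conv_1_1'].state_dict(), 'conv_1_1.pth')
torch.save(net.conv.state_dict(), 'conv.pth')
# Nested sub-modules work the same way: net.sub.subsub.state_dict()

# Do whatever

net.a_module_dict['conv_1_1'].load_state_dict(torch.load('conv_1_1.pth'))
net.conv.load_state_dict(torch.load('conv.pth'))

and so on.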
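Another option, not from the reply above but using only standard `nn.Module` methods: filter the full `net.state_dict()` by key prefix, save that subset, and load it into the whole model with `strict=False`. This is a sketch; `SmallNet` and the chosen prefix are hypothetical. Because the keys keep their full names (e.g. `conv.weight`), the file must be loaded into the top-level model, not into the sub-module itself.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical stand-in for Net with one ModuleDict and one plain layer."""

    def __init__(self):
        super().__init__()
        self.a_module_dict = nn.ModuleDict({'conv_1_1': nn.Conv2d(3, 64, 3, padding=1)})
        self.conv = nn.Conv2d(32, 32, 3)  # kernel_size=3 is an assumption


net = SmallNet()
prefixes = ('conv.',)  # hypothetical choice: keep only the `conv` sub-module

# Keep only the state_dict entries that belong to the chosen sub-modules.
partial = {k: v for k, v in net.state_dict().items() if k.startswith(prefixes)}
path = os.path.join(tempfile.mkdtemp(), 'partial.pth')
torch.save(partial, path)

# Load into a fresh model; strict=False tolerates keys absent from the file
# and reports them instead of raising.
fresh = SmallNet()
result = fresh.load_state_dict(torch.load(path), strict=False)
print(result.missing_keys)     # a_module_dict.* keys, left at their initial values
print(result.unexpected_keys)  # []
```

Checking `missing_keys` after a non-strict load is a cheap way to confirm that only the intended sub-modules were skipped.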

Thank you for your time!