How to correctly 'exclude' a module from the state dict of its parent module?

There is a pre-trained model in a bigger model, and the parameters of the pre-trained model are always fixed during training. I want to exclude the parameters of the pre-trained model from the state_dict of the bigger model to save some disk space, since the pre-trained model could have a large number of parameters.

Here is a little example. The model to be trained is ‘MyModel’, and the fixed pre-trained model is the ‘resnet’ in the ‘Encoder’. I wonder how to exclude the parameters of the ‘resnet’ from the state_dict of the ‘Encoder’ model? I tried to override the ‘state_dict’ method and the ‘load_state_dict’ method of the ‘Encoder’, but the ‘load_state_dict’ method is never called when the state of ‘MyModel’ is being loaded.

import torch                         # needed for torch.no_grad() below
import torch.nn as nn
from torchvision.models import resnet152

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()           # required before assigning submodules
        self.resnet = resnet152()   # pre-trained model to be excluded

    def forward(self, x):
        return self.resnet(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()           # required before assigning submodules
        self.encoder = Encoder()
        self.decoder = ...          # some other network

    def forward(self, x):
        with torch.no_grad():
            feat = self.encoder(x)  # feature extraction only
        y = self.decoder(feat)      # do something else
        return y
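The load_state_dict override on Encoder is never reached because MyModel.load_state_dict walks the submodule tree itself and calls each child's internal _load_from_state_dict, not the child's load_state_dict. Saving works differently: a parent's state_dict does call each child's state_dict, so that override is honoured. Below is a minimal sketch, assuming a recent PyTorch release (2.x), that filters the resnet keys on save and registers a load post-hook (register_load_state_dict_post_hook) that drops them from missing_keys, so a strict load of MyModel still succeeds. nn.Linear layers stand in for resnet152 and the decoder, forward is omitted, and the hook name _ignore_missing_resnet is hypothetical.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Sketch: keeps its frozen `resnet` child out of the state dict (assumes PyTorch 2.x)."""

    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(8, 4)  # hypothetical stand-in for resnet152()
        for p in self.resnet.parameters():
            p.requires_grad_(False)  # frozen, as in the question
        # Runs after this module and its children are loaded,
        # even when the load was started from a parent module.
        self.register_load_state_dict_post_hook(Encoder._ignore_missing_resnet)

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # A parent's state_dict() calls this override with its own prefix,
        # e.g. "encoder.", and a shared destination dict.
        sd = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        for key in [k for k in sd if k.startswith(prefix + "resnet.")]:
            del sd[key]
        return sd

    @staticmethod
    def _ignore_missing_resnet(module, incompatible_keys):
        # Keys the frozen child would contribute, relative to this module
        frozen = ["resnet." + name for name in module.resnet.state_dict()]
        # missing_keys holds full prefixes (e.g. "encoder.resnet.weight"),
        # so match the bare name or a ".resnet.<param>" suffix; edit in place.
        incompatible_keys.missing_keys[:] = [
            k for k in incompatible_keys.missing_keys
            if not any(k == f or k.endswith("." + f) for f in frozen)
        ]


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = nn.Linear(4, 2)  # stand-in for "some other network"


src, dst = MyModel(), MyModel()
sd = src.state_dict()
print(sorted(sd))  # ['decoder.bias', 'decoder.weight']
dst.load_state_dict(sd)  # strict=True passes: the hook removed the resnet keys from missing_keys
```

The suffix match in the hook is a simplification: it would also hide a genuinely missing key from any other submodule that happens to be called resnet. The frozen weights still have to come from the pre-trained source when the model is constructed, since they are no longer in the checkpoint.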

Then why not simply remove the entries you don't want from the state dict before saving it, and load the result with strict=False so the missing pre-trained keys are tolerated?
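A minimal sketch of that suggestion, assuming the frozen sub-model's keys all start with "encoder.resnet." (the attribute names from the question). The helpers save_without_frozen and load_without_frozen and the constant FROZEN_PREFIX are hypothetical names; nn.Linear layers stand in for resnet152 and the decoder, and forward is omitted. Because strict=False would also hide real mismatches, the loader re-checks everything except the intentionally missing keys.

```python
import io

import torch
import torch.nn as nn


# Minimal stand-ins for the question's modules (forward() omitted)
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = nn.Linear(8, 4)  # stand-in for the frozen pre-trained model


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = nn.Linear(4, 2)  # stand-in for the trainable decoder


# Hypothetical: state-dict prefix of the frozen sub-model inside MyModel
FROZEN_PREFIX = "encoder.resnet."


def save_without_frozen(model, f):
    """Save every state-dict entry except the frozen sub-model's."""
    sd = {k: v for k, v in model.state_dict().items() if not k.startswith(FROZEN_PREFIX)}
    torch.save(sd, f)


def load_without_frozen(model, f):
    """Load a checkpoint written by save_without_frozen into model."""
    sd = torch.load(f, weights_only=True)
    result = model.load_state_dict(sd, strict=False)
    # Only the frozen keys may be missing; anything else is a real error.
    other_missing = [k for k in result.missing_keys if not k.startswith(FROZEN_PREFIX)]
    if other_missing or result.unexpected_keys:
        raise RuntimeError(f"missing: {other_missing}, unexpected: {result.unexpected_keys}")
    return result


# Usage: round-trip through an in-memory buffer instead of a file on disk
src, dst = MyModel(), MyModel()
buf = io.BytesIO()
save_without_frozen(src, buf)
buf.seek(0)
res = load_without_frozen(dst, buf)
print(sorted(res.missing_keys))  # ['encoder.resnet.bias', 'encoder.resnet.weight']
```

The pre-trained weights are not restored by this load, so in practice they would be loaded into the resnet from their original source when the model is built, before or after loading the checkpoint.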