How to customize the module `state_dict` & `load_state_dict`

In my model, there is a fixed, pretrained sub-module which I don’t want to store in state_dict() in order to keep the checkpoint size smaller.

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, ref_model):
        super().__init__()
        self.fc = nn.Linear(3, 4)   # trainable parameters
        self.ref_m = ref_model      # fixed, pretrained parameters
        for p in self.ref_m.parameters():
            p.requires_grad = False

ref_model = ...
model = MyModel(ref_model)
# how to exclude ref_model from state_dict()?
checkpoint = model.state_dict()

You could manually delete the unwanted ref_m parameters and buffers from the state_dict directly, as seen here:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)         # trainable parameters
        self.ref_m = nn.Linear(10, 10)    # fixed parameters
        for p in self.ref_m.parameters():
            p.requires_grad = False

model = MyModel()
checkpoint = model.state_dict()

# collect the keys first, since entries can't be deleted while iterating
keys_to_delete = []
for key in checkpoint:
    if 'ref_m' in key:
        keys_to_delete.append(key)
for key in keys_to_delete:
    del checkpoint[key]
print(checkpoint)

Note, however, that you won’t be able to use model.load_state_dict(checkpoint) with the default setting of strict=True anymore, since keys will obviously be missing. In this case you would have to use strict=False and make sure that only the expected ref_m keys are missing.
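E.g. something like this (reusing the MyModel definition from above) would let you verify that: load_state_dict returns the missing and unexpected keys, so you can check that only the intentionally removed ref_m entries are absent.

new_model = MyModel()
ret = new_model.load_state_dict(checkpoint, strict=False)
# load_state_dict returns a named tuple with missing_keys and unexpected_keys
assert all('ref_m' in key for key in ret.missing_keys), ret.missing_keys
assert not ret.unexpected_keys, ret.unexpected_keys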

I assume ref_model is an entire model and not just buffers, as you could use self.register_buffer with persistent=False in the latter case and could avoid manipulating the state_dict.
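For the buffer case, a small example (the class name, buffer name, and shape are just placeholders): buffers registered with persistent=False are skipped by state_dict() automatically, but still move with the module via .to()/.cuda().

import torch
import torch.nn as nn

class MyModelWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)
        # non-persistent buffer: behaves like any buffer on the module,
        # but is not included in state_dict()
        self.register_buffer('ref_tensor', torch.randn(10, 10), persistent=False)

model = MyModelWithBuffer()
print(model.state_dict().keys())  # only fc.weight and fc.bias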

Thanks @ptrblck, my current workaround is just as you suggested. I wonder if it’s possible to implement it in a more “elegant” way, i.e. handle the case inside MyModel itself, so that the interface seen by the calling code stays the usual one.

I also tried overriding the state_dict and load_state_dict methods of the model, but it seems these overrides are never called once the model is wrapped in DDP or DP.

I don’t know if there is a better approach. While state_dict hooks are used internally, they can change at any time and could thus easily break your code if you depend on them.
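Just to illustrate the idea (not a recommendation, and note that _register_state_dict_hook is a private API whose behavior can change between releases), a hook-based version could look like this:

# relies on the private _register_state_dict_hook API -- use at your own risk
def drop_ref_m(module, state_dict, prefix, local_metadata):
    # called after state_dict has been populated; modify it in place
    for key in list(state_dict.keys()):
        if 'ref_m' in key:
            del state_dict[key]

model._register_state_dict_hook(drop_ref_m)
checkpoint = model.state_dict()  # no longer contains ref_m entries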

I tried something like this from inside my model, which is a subclass of torch.nn.Module. It seems to work both with and without the DDP wrapper.

Could you let me know if you see any issues in this approach?

def state_dict(self, destination=None, prefix='', keep_vars=False):
    my_state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
    # keep only the LoRA parameters -- your filtering logic here
    ret_state_dict = {k: v for k, v in my_state_dict.items() if 'lora_' in k}
    return ret_state_dict
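For the loading side I was planning something symmetric (just a sketch of my idea, not fully tested): since the saved checkpoint only contains the lora_ entries, the override forwards strict=False so the missing base-model keys don’t raise an error.

def load_state_dict(self, state_dict, strict=True):
    # the checkpoint only contains the 'lora_' entries, so strict=True would
    # complain about the missing base-model keys; load non-strictly instead
    return super().load_state_dict(state_dict, strict=False)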