In my model there is a fixed, pretrained sub-module which I don’t want to store in state_dict(), to keep the checkpoint size smaller.
class MyModel(nn.Module):
    def __init__(self, ref_model):
        super().__init__()
        self.fc = nn.Linear(3, 4)  # trainable parameters
        self.ref_m = ref_model     # fixed parameters
        for p in self.ref_m.parameters():
            p.requires_grad = False

ref_model = ...
model = MyModel(ref_model)
# how to exclude ref_model from state_dict()?
checkpoint = model.state_dict()
You could manually delete the unwanted ref_model parameters and buffers directly as seen here:
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)       # trainable parameters
        self.ref_m = nn.Linear(10, 10)  # fixed parameters
        for p in self.ref_m.parameters():
            p.requires_grad = False

model = MyModel()
checkpoint = model.state_dict()

# collect and drop all keys belonging to the fixed sub-module
keys_to_delete = [key for key in checkpoint if 'ref_m' in key]
for key in keys_to_delete:
    del checkpoint[key]

print(checkpoint)
Note, however, that you won’t be able to use model.load_state_dict(checkpoint) with the default setting of strict=True anymore, since keys are obviously missing. In this case you would have to use strict=False and make sure that only the expected ref_m keys are missing.
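Since load_state_dict returns the missing and unexpected keys, a quick sanity check could look like this (a minimal sketch, reusing the checkpoint from above):

# strict=False tolerates missing keys; verify that only ref_m entries are absent
incompatible = model.load_state_dict(checkpoint, strict=False)
assert all(key.startswith('ref_m.') for key in incompatible.missing_keys)
assert not incompatible.unexpected_keys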
I assume ref_model is an entire model and not just buffers, as in the latter case you could use self.register_buffer with persistent=False and avoid manipulating the state_dict.
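For reference, the buffer case could look like this (a minimal sketch; the buffer name and shape are made up):

import torch
import torch.nn as nn

class BufferModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)
        # non-persistent buffer: usable in forward, but excluded from state_dict()
        self.register_buffer('ref_table', torch.randn(10, 10), persistent=False)

model = BufferModel()
print(model.state_dict().keys())  # only the fc parameters appear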
Thanks @ptrblck, my current workaround is exactly what you suggested. I wonder if it’s possible to implement it in a more “elegant” way, i.e. handle the case inside MyModel itself, so that the calling code can keep using the usual interface.
I also tried overriding the model’s state_dict and load_state_dict, but it seems these methods are never called once the model is wrapped in DDP or DP.
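Roughly like this (a sketch only, continuing the MyModel example from above):

class MyModel(nn.Module):
    def __init__(self, ref_model):
        super().__init__()
        self.fc = nn.Linear(3, 4)
        self.ref_m = ref_model
        for p in self.ref_m.parameters():
            p.requires_grad = False

    def state_dict(self, *args, **kwargs):
        # drop the ref_m entries before handing the state_dict back
        sd = super().state_dict(*args, **kwargs)
        for key in [k for k in sd if 'ref_m' in k]:
            del sd[key]
        return sd

    def load_state_dict(self, state_dict, strict=True):
        # ref_m keys are expected to be missing, so don't enforce strict loading
        return super().load_state_dict(state_dict, strict=False)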
I don’t know if there is a better approach. While internally state_dict hooks are used, they can change at any time and could thus easily break your code if you depend on them.
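A rough sketch of what that could look like, relying on the private _register_state_dict_hook API (not a stable interface, so it may change between releases):

def drop_ref_m(module, state_dict, prefix, local_metadata):
    # remove the ref_m entries whenever the state_dict of this module is built
    for key in [k for k in state_dict if 'ref_m' in k]:
        del state_dict[key]
    return state_dict

model = MyModel(ref_model)
model._register_state_dict_hook(drop_ref_m)  # private API, use at your own risk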