In my model there is a fixed, pretrained sub-module that I don’t want to store in state_dict(), to keep the checkpoint size small.
class MyModel(nn.Module):
    def __init__(self, ref_model):
        super().__init__()
        self.fc = nn.Linear(3, 4)  # trainable parameters
        self.ref_m = ref_model     # fixed parameters
        for m in self.ref_m.parameters():
            m.requires_grad = False

ref_model = ...
model = MyModel(ref_model)
# how to exclude ref_model from state_dict()?
checkpoint = model.state_dict()
You could manually delete the unwanted ref_model parameters and buffers directly as seen here:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)       # trainable parameters
        self.ref_m = nn.Linear(10, 10)  # fixed parameters
        for m in self.ref_m.parameters():
            m.requires_grad = False

model = MyModel()
checkpoint = model.state_dict()

keys_to_delete = []
for key in checkpoint:
    if 'ref_m' in key:
        keys_to_delete.append(key)
for key in keys_to_delete:
    del checkpoint[key]
print(checkpoint)
Note, however, that you won’t be able to use model.load_state_dict(checkpoint) with the default strict=True anymore, since keys are now missing. In this case you would have to pass strict=False and verify that only the expected ref_m keys are missing.
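A concrete sketch of that round trip (module names fc/ref_m taken from the snippets above): filter the checkpoint, then reload it with strict=False and inspect the returned missing_keys:

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)       # trainable parameters
        self.ref_m = nn.Linear(10, 10)  # fixed parameters
        for p in self.ref_m.parameters():
            p.requires_grad = False

model = MyModel()

# keep everything except the ref_m entries
checkpoint = {k: v for k, v in model.state_dict().items() if 'ref_m' not in k}

# strict=False returns a named tuple listing skipped keys,
# so we can check that only the ref_m entries are missing
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)  # only ref_m.* keys should appear here
```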
I assume ref_model is an entire model and not just buffers, as you could use self.register_buffer with persistent=False in the latter case and could avoid manipulating the state_dict.
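For completeness, a minimal sketch of the buffer case (the tensor name ref_weight is made up for illustration): a buffer registered with persistent=False follows the module to the right device like any buffer, but never appears in state_dict():

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)
        # non-persistent buffer: moved by .to()/.cuda() like any buffer,
        # but automatically excluded from state_dict()
        self.register_buffer('ref_weight', torch.randn(10, 10), persistent=False)

model = MyModel()
print(sorted(model.state_dict().keys()))  # no 'ref_weight' entry
```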
Thanks @ptrblck, my current workaround is exactly as you suggested. I wonder if it’s possible to implement it in a more “elegant” way, i.e. handle the case inside MyModel so that the interface to the calling code stays the same as usual.
I also tried overriding the model’s state_dict and load_state_dict methods, but it seems these functions are never called if the model is wrapped in DDP or DP.
I don’t know if there is a better approach. While internally state_dict hooks are used, they can change at any time and could thus easily break your code if you depend on them.
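For reference, such a hook could look like the sketch below. Note that _register_state_dict_hook is a private API (leading underscore), so, as said above, it may change at any time; this is an illustration, not a recommendation:

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)
        self.ref_m = nn.Linear(10, 10)

def drop_ref_keys(module, state_dict, prefix, local_metadata):
    # called after state_dict() has been populated; mutate it in place
    for key in list(state_dict.keys()):
        if 'ref_m' in key:
            del state_dict[key]

model = MyModel()
model._register_state_dict_hook(drop_ref_keys)
print(sorted(model.state_dict().keys()))  # fc.* keys only
```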
I tried something like this from inside my model, which is a subclass of torch.nn.Module. It seems to work both when wrapped in DDP and without.
Could you let me know if you see any issues in this approach?
def state_dict(self, destination=None, prefix='', keep_vars=False):
    my_state_dict = super().state_dict(destination, prefix, keep_vars)
    ret_state_dict = {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}  # your logic here
    return ret_state_dict