I want to remove and modify parts of a pretrained model's code, and when loading the pretrained weights I want to use `strict=False` when the `state_dict` is loaded. In the code I'm working with, the pretrained model I need to use ([TimeSformer](https://github.com/facebookresearch/TimeSformer)) loads its weights from inside the model itself. In my previous experiments (with the same code base) I loaded the pretrained weights in a separate section of the code, but in my current code I cannot find any place to apply `strict=False`. How can I apply `strict=False` to avoid mismatches caused by differences between the modified model code and the pretrained weights? Any help would be appreciated.
This is the code I used in my previous experiments:

    import os

    import torch
    import torch.nn as nn

    class my_model(nn.Module):
        def __init__(self, pretrained=False):
            super(my_model, self).__init__()
            self.featureExtractor = feature_extractor()
            if pretrained:
                print('Loading weights...')
                weight_dict = torch.load(os.path.join('models', 'vid_class.pt'))
                model_dict = self.featureExtractor.state_dict()
                list_weight_dict = list(weight_dict.items())
                list_model_dict = list(model_dict.items())
                # Copy weights by position (not by name), checking shapes as we go
                for i in range(len(list_model_dict)):
                    assert list_model_dict[i][1].shape == list_weight_dict[i][1].shape
                    model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
                # Verify that every tensor was copied correctly
                for i in range(len(list_model_dict)):
                    assert torch.all(torch.eq(model_dict[list_model_dict[i][0]], weight_dict[list_weight_dict[i][0]].to('cpu')))
                print('Loading done!')
But `TimeSformer` loads the weights by itself, and I cannot find where to apply `strict=False` when the `state_dict` is loaded:

    class my_model(nn.Module):
        def __init__(self, pretrained=False):
            super(my_model, self).__init__()
            self.featureExtractor = TimeSformer(img_size=224, num_classes=400, num_frames=8,
                                                attention_type='divided_space_time',
                                                pretrained_model='/models/TimeSformer.pyth')
            ..
            ..
            ..
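
For reference, here is a minimal sketch of what `strict=False` does and does not cover, using toy modules rather than TimeSformer itself. `OriginalNet`, `ModifiedNet`, and `load_compatible` are hypothetical names made up for this illustration. The point it shows: `load_state_dict(..., strict=False)` tolerates missing and unexpected keys, but a tensor whose name matches and whose shape differs still raises an error, so such entries are filtered out first. How this would be wired into TimeSformer (for example, how its checkpoint file is laid out and whether the model can be built without loading weights internally) is not shown and would need to be checked against that repository.

```python
import torch
import torch.nn as nn


class OriginalNet(nn.Module):
    """Stand-in for the architecture the checkpoint was saved from."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 400)


class ModifiedNet(nn.Module):
    """Stand-in for the edited model: resized head plus one new layer."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)  # unchanged layer
        self.head = nn.Linear(16, 10)     # same name, different shape
        self.extra = nn.Linear(16, 16)    # not present in the checkpoint


def load_compatible(model, checkpoint_state):
    """Load only the entries whose name and shape both match the model."""
    model_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in checkpoint_state.items()
        if name in model_state and tensor.shape == model_state[name].shape
    }
    # strict=False: keys missing from `compatible` keep their initial values
    result = model.load_state_dict(compatible, strict=False)
    return compatible, result


checkpoint_state = OriginalNet().state_dict()
model = ModifiedNet()
loaded, result = load_compatible(model, checkpoint_state)

print("loaded:", sorted(loaded))
print("missing:", sorted(result.missing_keys))
print("unexpected:", sorted(result.unexpected_keys))
```

In this toy case only the `backbone` tensors are loaded, while `head` and `extra` are reported as missing and keep their random initialization. Printing the returned `missing_keys` and `unexpected_keys` is a cheap way to confirm that nothing was skipped by accident.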