Regarding updating the weight of a model

In the following, I am updating the weights in ResNet-101 without a problem:

base_ = ResNet(Bottleneck, [3, 4, 23, 3])
resnet_weights = torch.load('resnet101_caffe.pth')
base_.load_state_dict(resnet_weights, strict=False)

If I do
print(resnet_weights.keys())

I get

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight',
.
.
.

My question is:

If I use

model = models.resnet101(pretrained = True)

which comes from PyTorch, how can I use the weights of this model to update my model via base_.load_state_dict(resnet_weights, strict=False),
like I did above? Is it even possible?

Hi,

If the keys are the same, you can do the following:

model = models.resnet101(pretrained = True)
base_.load_state_dict(model.state_dict(), strict=False)
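
To check which keys actually matched, it can help to look at what load_state_dict returns: it reports the keys that were missing from the loaded weights and the keys that were not expected by the model. Below is a minimal sketch, assuming two small stand-in modules (Source and Target are hypothetical names for illustration, not the ResNet classes above). Note that strict=False only tolerates missing or extra keys; a key that exists in both models but with a different tensor shape still raises an error.

```python
import torch
import torch.nn as nn


class Source(nn.Module):
    """Hypothetical stand-in for the pretrained model."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, bias=False)
        self.fc = nn.Linear(4, 2)


class Target(nn.Module):
    """Hypothetical stand-in for base_: same conv1, different head name."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, bias=False)
        self.head = nn.Linear(4, 5)


source, target = Source(), Target()

# strict=False copies every key that exists in both models and skips the rest.
result = target.load_state_dict(source.state_dict(), strict=False)

# Keys the target has but the loaded weights did not provide.
print("missing:", result.missing_keys)
# Keys in the loaded weights that the target does not have.
print("unexpected:", result.unexpected_keys)
```

If both lists are empty, every parameter was loaded. If nearly everything shows up as missing, the key names probably differ (for example, because of a prefix) and the weights were silently skipped.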

I see, thank you for your answer.
I have a follow-up question. Let's say my model is model = models.resnet101(pretrained = True), and I fine-tune it and then want to save it and load it again.
Do you suggest saving and loading it like this:
torch.save(the_model.state_dict(), PATH)

the_model.load_state_dict(torch.load(PATH))

Or do you have another suggestion?

The link says to do
the_model = TheModelClass(*args, **kwargs)
first, but I don't understand why, or what it means by TheModelClass.

Yes, what you want to do here is fine.
What it means is that creating the model and loading the weights are two different things: you don’t save an nn.Module with all its weights. You save the weights on one side with .state_dict(), and the module’s info on the other side by saving the arguments used to create it. TheModelClass is simply a placeholder for whatever class your model is an instance of.
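
The round trip described above can be sketched as follows. This is only an illustration under assumptions: TinyNet, its constructor arguments, and the file name are hypothetical stand-ins for TheModelClass, *args/**kwargs, and PATH.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for TheModelClass."""

    def __init__(self, in_features, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.fc(x)


# The "module's info": the arguments needed to rebuild the model.
model_kwargs = {"in_features": 8, "num_classes": 3}
the_model = TinyNet(**model_kwargs)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_net.pth")  # hypothetical PATH

    # Save only the weights.
    torch.save(the_model.state_dict(), path)

    # Later: first recreate the model from its class and arguments,
    # then load the saved weights into it.
    restored = TinyNet(**model_kwargs)
    restored.load_state_dict(torch.load(path))

# Switch to evaluation mode before running inference.
restored.eval()
```

The file contains only tensors keyed by name, so it stays loadable as long as you can construct a model with matching parameter names and shapes.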

I see! Thanks

Well, can I save and load an nn.Module with all its weights? If so, how?

Sorry, that was not clear: “You don’t” means “You should not”.
Saving Python class instances like models breaks so easily, on so many levels, that you have a very high chance of ending up with an object that you will never be able to load again.
You can do it, but you shouldn’t, so I won’t tell you how to do it :smiley:

I see, fair enough :smile: