Hello,
In my case I have a saved model named best_model.pkl, which contains multiple models: pose, depth, and others. I want to use the existing best_model.pkl to extract the individual models and save them as separate .pth files.
I am doing something along these lines:
import os
import torch
model_to_save = torch.load('best_model.pkl')
print(model_to_save.keys())
which gives the following output:
dict_keys(['epoch', 'model_state', 'optimizer_state', 'scheduler_state', 'best_iou'])
From these I select model_state, which holds each of the individual models. Now I am trying to save each individual model with the following code:
models = ["depth", "pose_encoder", "pose", "encoder"]
for model_name in models:
    save_path = os.path.join("{}.pth".format(model_name))
    to_save = model_to_save['model_state']['models'][model_name].state_dict()
    torch.save(to_save, save_path)
But I receive the following error:
  File "save.py", line 10, in save_monodepth_models
    to_save = model_to_save['model_state']['models'][model_name].state_dict()
KeyError: 'models'
I see that type(model_to_save['model_state']) is <class 'collections.OrderedDict'>, and the keys inside model_to_save['model_state'] are:
odict_keys(['models.encoder.encoder.conv1.weight',
'models.encoder.encoder.bn1.weight',
'models.encoder.encoder.bn1.bias',
'models.encoder.encoder.bn1.running_mean',
'models.encoder.encoder.bn1.running_var',
'models.encoder.encoder.bn1.num_batches_tracked',
'models.encoder.encoder.layer1.0.conv1.weight',
'models.encoder.encoder.layer1.0.bn1.weight',
'models.encoder.encoder.layer1.0.bn1.bias',
'models.encoder.encoder.layer1.0.bn1.running_mean',
'models.encoder.encoder.layer1.0.bn1.running_var',
'models.encoder.encoder.layer1.0.bn1.num_batches_tracked',
'models.encoder.encoder.layer1.0.conv2.weight',
'models.encoder.encoder.layer1.0.bn2.weight',
'models.encoder.encoder.layer1.0.bn2.bias',
'models.encoder.encoder.layer1.0.bn2.running_mean',
'models.encoder.encoder.layer1.0.bn2.running_var',
'models.encoder.encoder.layer1.0.bn2.num_batches_tracked',
'models.encoder.encoder.layer1.0.conv3.weight',
'models.encoder.encoder.layer1.0.bn3.weight',
'models.encoder.encoder.layer1.0.bn3.bias',
'models.encoder.encoder.layer1.0.bn3.running_mean',
'models.encoder.encoder.layer1.0.bn3.running_var',................
I am not exactly sure how I should handle this case. Any suggestions would be helpful!
Regards!
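
The key listing suggests that model_state is a single flat state dict whose keys start with 'models.<name>.', rather than a nested dict with a 'models' entry, which would explain the KeyError. One possible way to handle this is to group the entries by that prefix. The following is only a minimal sketch: split_state_dict is a hypothetical helper, the 'models.depth.conv.weight' key is made up for illustration, and plain integers stand in for the real tensors so the snippet runs without torch.

```python
from collections import OrderedDict


def split_state_dict(flat_state, prefix="models."):
    """Group a flat state dict by sub-model name (hypothetical helper).

    Keys are assumed to look like 'models.<name>.<rest>'. The result maps
    each <name> to an OrderedDict of '<rest>' -> value.
    """
    grouped = OrderedDict()
    for key, value in flat_state.items():
        if not key.startswith(prefix):
            continue  # skip keys outside the assumed 'models.' namespace
        # Split 'encoder.encoder.conv1.weight' into ('encoder', 'encoder.conv1.weight').
        name, _, rest = key[len(prefix):].partition(".")
        grouped.setdefault(name, OrderedDict())[rest] = value
    return grouped


# Stand-in for model_to_save['model_state']; the real values would be tensors.
flat_state = OrderedDict([
    ("models.encoder.encoder.conv1.weight", 0),
    ("models.encoder.encoder.bn1.weight", 1),
    ("models.depth.conv.weight", 2),  # hypothetical key, for illustration only
])

per_model = split_state_dict(flat_state)
for name, state in per_model.items():
    print(name, list(state.keys()))
```

With the real checkpoint, each per_model[name] could then be passed to torch.save(per_model[name], "{}.pth".format(name)). Whether the stripped keys match what each sub-model's load_state_dict expects depends on how those models were defined, so that part is an assumption to verify.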