Recreate CNN from state dict

Hi, I periodically save my model along with some training state. The model is CNN -> transformer encoder -> CTC loss. The CNN is a PyTorch resnet18 that is fine-tuned as the whole model trains (the final classification layer is removed and a 512-unit fc layer is added).

model_path = "{}/{}.ckpt".format(self.model_dir, self.steps)
state = {
    "steps": self.steps,
    "model_state": self.model.state_dict(),
    "optimizer_state": self.optimizer.state_dict(),
    "scheduler_state": self.scheduler.state_dict() if self.scheduler else None,
}
torch.save(state, model_path)

Now I need to use the CNN to generate embeddings for new images, i.e. forward-pass a new image through the CNN alone and get the [1, 512] output tensor.

This is how I load the checkpoint:

import torch

file = '128000.ckpt'
best_model = torch.load(file, map_location=torch.device('cpu'))
state_dict = best_model['model_state']

The state_dict is an OrderedDict containing the weights and biases of every part of the combined model. I can see the keys belonging to the CNN; however, I am not sure how to order them to recreate the CNN.

Any ideas on how to do this would be appreciated.

OK, I found a way to do this:

  1. Get the state dict of the original resnet18 that was used in training.
  2. Get the state dict of the collectively trained model.
  3. Build a new state dict from the collectively trained model's entries whose keys also appear in the state dict from step 1.

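import torch
from torchvision import models
from torch.hub import load_state_dict_from_url
from collections import OrderedDict


# get best model state dict
file = 'best.ckpt'
best_model = torch.load(file, map_location=torch.device('cpu'))
best_state_dict = best_model['model_state']

# get original resnet18; fc outputs the 512 features added for training
num_classes = 512
cnn = models.resnet18(num_classes=num_classes)
# model_urls was undefined here and recent torchvision no longer provides it;
# the weights enum (torchvision >= 0.13) exposes the download URL instead
cnn_state_dict = load_state_dict_from_url(models.ResNet18_Weights.IMAGENET1K_V1.url)

# create new tuned resnet using best_state_dict
# (only matches if the checkpoint stores the CNN keys without a prefix)
tuned_cnn_state_dict = OrderedDict(
    (k, v) for k, v in best_state_dict.items() if k in cnn_state_dict
)

# finally load the weights and biases; strict=False silently skips absent keys,
# so check that no layer was left at its random initialization
result = cnn.load_state_dict(tuned_cnn_state_dict, strict=False)
print("missing keys:", result.missing_keys)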