Hi, I save my model periodically along with relevant information. The model is CNN -> transformer encoder -> CTC loss. The CNN is PyTorch's resnet18, which is fine-tuned as the whole model trains (the softmax layer is removed and an fc layer of length 512 is added).
model_path = "{}/{}.ckpt".format(self.model_dir, self.steps)
state = {
    "steps": self.steps,
    "model_state": self.model.state_dict(),
    "optimizer_state": self.optimizer.state_dict(),
    "scheduler_state": self.scheduler.state_dict() if self.scheduler else None,
}
torch.save(state, model_path)
Now I need to use the CNN to generate embeddings for new images. Essentially, I want to forward-pass a new image through just the CNN and get the [1,512] output tensor. This is how I load the checkpoint:
file = '128000.ckpt'
best_model = torch.load(file, map_location=torch.device('cpu'))
state_dict = best_model['model_state']
The state_dict is an OrderedDict and holds the weights and biases of everything in the combined model. I can see the layers associated with the CNN in the dictionary keys; however, I am not sure how to order them to recreate the CNN.
Any ideas on how to do this are appreciated.
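One possible approach, as a rough sketch rather than a definitive answer: the entries do not need to be ordered by hand, because `load_state_dict` matches parameters by key name. It should be enough to keep only the CNN's entries, strip the submodule prefix from their keys, and load them into a freshly built resnet18 with the same head. The prefix `"cnn."`, the helper names `extract_submodule_state` and `load_cnn`, and the `nn.Linear(512, 512)` head are assumptions for illustration; check them against `list(state_dict.keys())` and the original model definition.

```python
from collections import OrderedDict


def extract_submodule_state(state_dict, prefix):
    """Keep the entries whose key starts with `prefix`, with the prefix stripped."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


def load_cnn(ckpt_path, prefix="cnn."):
    """Rebuild the ResNet-18 branch and load its weights from the checkpoint.

    ASSUMPTIONS: the CNN lives under `prefix` in the full model, and its
    original fc layer was replaced by a 512-unit linear layer.
    """
    import torch
    import torch.nn as nn
    from torchvision import models

    checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
    cnn_state = extract_submodule_state(checkpoint["model_state"], prefix)

    cnn = models.resnet18()  # architecture only; weights come from the checkpoint
    cnn.fc = nn.Linear(cnn.fc.in_features, 512)  # assumed replacement head
    cnn.load_state_dict(cnn_state)  # strict by default: raises on any key mismatch
    cnn.eval()  # inference mode for the BatchNorm layers
    return cnn
```

With that in place, something like `cnn = load_cnn('128000.ckpt')` followed by `with torch.no_grad(): embedding = cnn(image)` should give the [1,512] tensor for an `image` of shape [1, 3, H, W]. If `load_state_dict` complains about missing or unexpected keys, the error message lists them, which usually shows whether the prefix or the head differs from the assumption.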