Extract core model from Siamese Network

I have trained a model with the following general architecture:

  1. Model takes two inputs
  2. Runs both of them through an embedding model, which returns an embedding for each input, say of shape a → [batch_size, 748] and b → [batch_size, 748].
  3. Takes the absolute difference, i.e. torch.abs(a - b), and passes it to a fully connected layer.
  4. Trains the model with cross-entropy loss.
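For reference, the architecture described above can be sketched roughly like this (the class names, layer sizes, and 100-dimensional input are hypothetical placeholders, not the actual trained model):

```python
import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    def __init__(self, embedding_dim=748, num_classes=2):
        super().__init__()
        # Embedding sub-model; the real one could be any encoder
        self.embedding_model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_dim),
        )
        # Classifier head applied to the absolute difference
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x1, x2):
        a = self.embedding_model(x1)  # [batch_size, 748]
        b = self.embedding_model(x2)  # [batch_size, 748]
        return self.fc(torch.abs(a - b))  # logits for cross-entropy loss
```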

Now, I want to extract the embedding model such that I can pass one input and get an embedding. How can this be achieved?
A forward hook doesn’t seem applicable here, since my Siamese network takes two inputs, so I can’t pass a single input and get the embedding.

Thank you.

It’s actually pretty simple if you know how the embedding model is defined.
Initialize a new model with the same layer configuration as the embedding model.
Say siameseNetwork is the trained network, containing the embedding model as a submodule named “embedding_model” (you can confirm this with print(siameseNetwork)). Then do something like:

embedding_model = EmbeddingModel()
embedding_model.load_state_dict(siameseNetwork.embedding_model.state_dict())

Now embedding_model will have the correct weights.
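A minimal end-to-end sketch of this weight transfer, using hypothetical class definitions and sizes (the real EmbeddingModel would be whatever encoder was used in training):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the embedding sub-model used during training
class EmbeddingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 748),
        )

    def forward(self, x):
        return self.net(x)

class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_model = EmbeddingModel()
        self.fc = nn.Linear(748, 2)

    def forward(self, x1, x2):
        a, b = self.embedding_model(x1), self.embedding_model(x2)
        return self.fc(torch.abs(a - b))

siameseNetwork = SiameseNetwork()  # pretend this has been trained

# Copy the trained weights into a standalone embedding model
embedding_model = EmbeddingModel()
embedding_model.load_state_dict(siameseNetwork.embedding_model.state_dict())
embedding_model.eval()

# Single-input inference
with torch.no_grad():
    emb = embedding_model(torch.randn(1, 100))  # shape [1, 748]
```

Note that since a submodule is itself an nn.Module, you could also skip the copy and call siameseNetwork.embedding_model(x) directly; the explicit state_dict transfer is mainly useful if you want a standalone model you can save and ship on its own.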