Loading a saved model with (hacky) weight sharing results in gibberish

I have been working on sequence-to-sequence models and have tried many different variants. I save and load the models; they reach different levels of performance, but most seem to work as expected.

Now I’m trying the idea of sharing the embedding matrix with the output linear layer’s weight matrix. On this forum, I found a way of doing that. It looked like a hack to me, but at first it seemed to work. Later I realized that saving and loading the model breaks it.

Within the model class, I do the following:

self.embeds = nn.Embedding(vocab_size, embed_len)
self.out_linear = nn.Linear(embed_len, vocab_size, bias=False)
# share the embedding matrix with the output projection (this is the hacky part)
self.out_linear.weight.data = self.embeds.weight.data
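
To spell out what that last assignment does: it makes the linear layer’s weight point at the very same underlying tensor as the embedding, so an update to one is an update to the other. A quick way to confirm this, with made-up sizes:

import torch.nn as nn

embeds = nn.Embedding(1000, 64)
out_linear = nn.Linear(64, 1000, bias=False)
out_linear.weight.data = embeds.weight.data
# both parameters now reference the same storage
print(out_linear.weight.data.data_ptr() == embeds.weight.data.data_ptr())  # True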

During training, this actually achieves the best validation error so far. However, when I load the saved model back, the outputs are complete gibberish. The code I use to save and load the model is:

torch.save(model.state_dict(), file)
model.load_state_dict(torch.load(file))
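
(For completeness, the full round trip is the usual pattern: save the trained model’s state_dict, build a fresh model instance with the same constructor arguments, then load into it. Roughly, with placeholder names:)

# save after training; "seq2seq.pt" is a placeholder path
torch.save(model.state_dict(), "seq2seq.pt")

# later: rebuild the model with the same arguments, then load the weights
model = Seq2SeqModel(vocab_size, embed_len)              # Seq2SeqModel stands in for my real class
model.load_state_dict(torch.load("seq2seq.pt", map_location="cpu"))
model.eval()                                             # inference mode before generating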

I have no reason to believe the error is elsewhere, since the weight sharing is the only change between this model and the previous one, which works fine. Can anyone tell me what is going wrong? And what is the right way to share the matrix between the embedding layer and the output linear layer?

Any help would be greatly appreciated!

I figured it out. Basically, instead of using a separate output linear layer, I should compute the output scores directly with torch.matmul and self.embeds.weight. Doing this solved the saving/loading problem. The hacky way apparently doesn’t work!

Here logits (before the projection) is a Variable of shape (Batch, Embedding_Dimension) and self.embeds.weight is a Parameter of shape (Vocabulary_Size, Embedding_Dimension), so the matmul below produces logits of shape (Batch, Vocabulary_Size):

logits = torch.matmul(logits, self.embeds.weight.t())
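
In context, the fixed model looks roughly like this (the class and argument names are just for illustration, not my actual code):

import torch
import torch.nn as nn

class TiedOutputDecoder(nn.Module):
    def __init__(self, vocab_size, embed_len):
        super().__init__()
        # only the embedding holds the shared matrix; no separate output linear layer
        self.embeds = nn.Embedding(vocab_size, embed_len)

    def forward(self, hidden):
        # hidden: (batch, embedding_dim); embeds.weight: (vocab_size, embedding_dim)
        logits = torch.matmul(hidden, self.embeds.weight.t())
        return logits  # (batch, vocab_size)

Since only one Parameter now holds the matrix, the state_dict contains a single copy of it, and load_state_dict restores it without any ambiguity.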