Transformer: reinitialising the embedding layer breaks the model?

I'm adding new tokens to the vocabulary of a pretrained Transformer model. To do that, I need to reinitialise the embedding layer, expanding its weight matrix to include rows for the new tokens. However, running my expand_embeddings method breaks something in the model: the loss goes up and the predictions become rubbish.
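
For context, expanding a token embedding usually amounts to building a larger table and copying the pretrained rows into it. The sketch below assumes that the unqualified Embedding in the post behaves like torch.nn.Embedding with no padding_idx. expand_token_embedding is a hypothetical helper, not part of PyTorch or of the model in this post.

```python
import torch
import torch.nn as nn


def expand_token_embedding(old: nn.Embedding, new_vocab_size: int) -> nn.Embedding:
    """Return a larger nn.Embedding whose first rows are copied from `old` (hypothetical helper)."""
    old_vocab_size, d_model = old.weight.shape
    if new_vocab_size < old_vocab_size:
        raise ValueError("new_vocab_size must be at least the current vocabulary size")
    # Rows for the new token ids keep nn.Embedding's default init, N(0, 1)
    new = nn.Embedding(new_vocab_size, d_model,
                       device=old.weight.device, dtype=old.weight.dtype)
    with torch.no_grad():
        # Both weights are (num_embeddings, embedding_dim), so no transpose is needed
        new.weight[:old_vocab_size] = old.weight
    return new


old = nn.Embedding(10, 4)
new = expand_token_embedding(old, 12)
print(new.weight.shape)  # torch.Size([12, 4])
```

With new_vocab_size equal to the old size, the result holds an exact copy of the old table.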

Here’s what it looks like:

    def expand_embeddings(self, new_vocab_size, device):
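        # Clone the token embedding weights, shape (vocab_size, d_model), and transpose
        # them to (d_model, vocab_size); EmbeddingLayer transposes them back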
        old_embeddings = self.embed.token_embed.weight.data.clone().transpose(0,1).to(device)

        # TEST: Don't include new tokens, instead simply reinitialise the layers as you would
        # if new tokens were to be added
        new_vocab_size = 32004
        
        # Initialise the embedding layer to the old weights
        self.embed = EmbeddingLayer(new_vocab_size, self.d_model, self.max_length, old_embeddings).to(device)
        # Tie weights with the output layer
        self.decoder.generator = Generator(self.d_model, new_vocab_size, self.embed.token_embed.weight.data).to(device)

Here are the Generator and EmbeddingLayer classes for reference (forward passes elided):

    class Generator(nn.Module):
        def __init__(self, d_model, vocab_size, tied_weights):
            super(Generator, self).__init__()
            self.proj = nn.Linear(d_model, vocab_size)
            self.proj.weight.data = tied_weights

        def forward(self, x):
            (...)

    class EmbeddingLayer(nn.Module):
        def __init__(self, vocab_size, d_model, max_length, tied_weights=None):
            super(EmbeddingLayer, self).__init__()
            self.token_embed = Embedding(vocab_size, d_model)
            if tied_weights is not None:
                # Must transpose weights since Linear transposes by default but Embedding does not
                self.token_embed.weight.data = tied_weights.clone().transpose(0, 1)
            self.pos_embed = Embedding(max_length, d_model)
            self.vocab_size = vocab_size

        def forward(self, x, pos):
            (...)
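
For reference, here is a small shape check using illustrative sizes, not the real model's. nn.Linear(d_model, vocab_size) stores its weight as (out_features, in_features), i.e. (vocab_size, d_model), and computes x @ weight.T + bias. Its weight therefore already has the same shape as nn.Embedding(vocab_size, d_model).weight. A common way to tie the two is to assign the Parameter object itself rather than its .data. This is a sketch of that idea, not the model's actual code.

```python
import torch.nn as nn

vocab_size, d_model = 100, 16  # illustrative sizes only

embed = nn.Embedding(vocab_size, d_model)  # weight: (num_embeddings, embedding_dim)
proj = nn.Linear(d_model, vocab_size)      # weight: (out_features, in_features)
print(embed.weight.shape, proj.weight.shape)  # torch.Size([100, 16]) torch.Size([100, 16])

# Tie by sharing the Parameter itself: both modules now hold the same object,
# so there is a single weight tensor to train
proj.weight = embed.weight
```

Tying this way shares only the weight. nn.Linear keeps its own bias, which is freshly initialised whenever the layer is constructed.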

When expand_embeddings is not run, the training loss starts at 15.12; when it is run, it starts at 93.06. Since this test adds no new tokens, I'd expect these values to be identical.
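
One way to test the expectation that the two runs should match is to snapshot the model's state_dict() before the rebuild and then list every tensor that is missing, reshaped or changed afterwards. This is a minimal sketch that assumes a plain PyTorch module. changed_tensors and Toy are hypothetical names, and the toy rebuild copies the token embedding but re-creates the positional embedding from scratch.

```python
import torch
import torch.nn as nn


def changed_tensors(before: dict, after: dict) -> list:
    """Return state_dict keys that are missing, new, reshaped, or whose values differ."""
    changed = []
    for name, old in before.items():
        new = after.get(name)
        if new is None or new.shape != old.shape or not torch.equal(new.cpu(), old.cpu()):
            changed.append(name)
    # Keys that exist only after the rebuild are reported as well
    changed.extend(name for name in after if name not in before)
    return changed


class Toy(nn.Module):
    """Hypothetical stand-in with a token and a positional embedding table."""

    def __init__(self, vocab_size=10, d_model=4, max_length=8):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_length, d_model)


model = Toy()
# state_dict() shares storage with the live parameters, so clone for a real snapshot
before = {name: t.clone() for name, t in model.state_dict().items()}

# Toy rebuild: copy the token embedding rows, but re-create pos_embed from scratch
old_tokens = model.token_embed.weight.detach().clone()
model.token_embed = nn.Embedding(10, 4)
with torch.no_grad():
    model.token_embed.weight.copy_(old_tokens)
model.pos_embed = nn.Embedding(8, 4)

print(changed_tensors(before, model.state_dict()))
```

You can run the same comparison on the real model. Take the snapshot before calling expand_embeddings and compare it with state_dict() afterwards. The result lists every pretrained tensor that the rebuild does not carry over unchanged.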