How to get embeddings from two embedding layers

I would like to freeze only one row of the embedding layer so that its weights are never updated during training. Some people suggested using two separate embedding layers: one for the trainable embeddings and another for the frozen embedding. But I am not sure how to look up embeddings from the two layers and combine them efficiently. Any help would be appreciated.

```python
import torch.nn as nn


class myModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # e.g. 200x100: 200 trainable relations, each of dimension 100
        self.trainable_rel_embeddings = nn.Embedding(self.config.relTotal, self.config.dim)
        self.trainable_rel_embeddings.weight.data = self.config.init_rel_embs
        # a single extra row that should never be updated
        self.freeze_rel_embeddings = nn.Embedding(1, self.config.dim)
        self.freeze_rel_embeddings.weight.data = self.config.freeze_init_rel_embs
        self.freeze_rel_embeddings.weight.requires_grad = False

    def forward(self, batch):
        # batch holds indices, e.g. one batch with 5 groups:
        # tensor([[  3, 200],
        #         [  3, 200],
        #         [  3, 200],
        #         [188,   2],
        #         [188, 200]])
        # How can I take embeddings with index < 200 from self.trainable_rel_embeddings
        # and the embedding with index 200 from self.freeze_rel_embeddings?
        ...
```

Given that the Embedding’s forward is literally just

```python
return F.embedding(
    input, self.weight, self.padding_idx, self.max_norm,
    self.norm_type, self.scale_grad_by_freq, self.sparse)
```

you could call F.embedding in your own forward and pass in the weights of both layers, concatenated with torch.cat.
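A minimal sketch of the torch.cat approach, reusing the layer names from the question. The class name `RelEmbeddings`, the sizes (200 trainable rows, dimension 100) and the default random initialization are assumptions for illustration. Index 200 maps to the frozen row because that row is appended right after the 200 trainable rows.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RelEmbeddings(nn.Module):
    """Hypothetical module: rows 0..num_trainable-1 train, row num_trainable is frozen."""

    def __init__(self, num_trainable: int, dim: int):
        super().__init__()
        self.trainable_rel_embeddings = nn.Embedding(num_trainable, dim)
        self.freeze_rel_embeddings = nn.Embedding(1, dim)
        # Frozen layer: autograd never computes a gradient for it.
        self.freeze_rel_embeddings.weight.requires_grad_(False)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # Stack both weight matrices into one (num_trainable + 1) x dim table,
        # so index num_trainable (e.g. 200) refers to the frozen row.
        weight = torch.cat(
            [self.trainable_rel_embeddings.weight, self.freeze_rel_embeddings.weight],
            dim=0,
        )
        return F.embedding(batch, weight)


torch.manual_seed(0)
model = RelEmbeddings(num_trainable=200, dim=100)
batch = torch.tensor([[3, 200], [3, 200], [3, 200], [188, 2], [188, 200]])
out = model(batch)
print(out.shape)  # torch.Size([5, 2, 100])

out.sum().backward()
print(model.freeze_rel_embeddings.weight.grad)  # None: the frozen row gets no gradient
```

Since the frozen weight has `requires_grad=False`, its `.grad` stays `None`, and PyTorch's built-in optimizers skip parameters whose `.grad` is `None`, so the row stays fixed even if all of `model.parameters()` is handed to the optimizer.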
You might even store the frozen weight as a plain parameter (or even a buffer) instead of giving it its own Embedding layer. (Or subclass Embedding into an EmbeddingWithFrozen that does all of this under the hood.)
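One possible shape for the EmbeddingWithFrozen idea, as a sketch under stated assumptions: the frozen rows are passed in as a tensor, stored as a buffer, and appended after the trainable rows. The constructor signature and the buffer name `frozen_weight` are hypothetical, and the sketch has only been reasoned through for the defaults `max_norm=None` and `sparse=False`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingWithFrozen(nn.Embedding):
    """Sketch: nn.Embedding holds the trainable rows; frozen rows live in a buffer
    and are appended after them (assumes max_norm=None, sparse=False)."""

    def __init__(self, num_trainable: int, embedding_dim: int,
                 frozen_rows: torch.Tensor, **kwargs):
        super().__init__(num_trainable, embedding_dim, **kwargs)
        # A buffer is saved in state_dict and moved by .to(), but it is not a
        # parameter, so optimizers never see it.
        self.register_buffer("frozen_weight", frozen_rows.detach().clone())

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = torch.cat([self.weight, self.frozen_weight], dim=0)
        return F.embedding(
            input, weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)


torch.manual_seed(0)
frozen = torch.zeros(1, 100)  # hypothetical fixed vector for index 200
emb = EmbeddingWithFrozen(200, 100, frozen)
out = emb(torch.tensor([[3, 200], [188, 2]]))
print(out.shape)  # torch.Size([2, 2, 100])
print([name for name, _ in emb.named_parameters()])  # ['weight']
```

Using a buffer rather than a parameter with `requires_grad=False` means the frozen row cannot accidentally end up in an optimizer's parameter list.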

There is a slight drawback: torch.cat copies the whole embedding weight on every forward pass. With more trickery you could reduce that by doing the embedding in two stages, but it is considerably less straightforward.
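One way to read the "two stages" suggestion, as a hedged sketch (the helper name `two_stage_lookup` is hypothetical): look up every index in the trainable table with frozen indices clamped to a placeholder row, look up the frozen rows separately, and merge the two results with torch.where. This avoids copying the weight matrix, at the cost of two lookups sized like the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def two_stage_lookup(trainable: nn.Embedding, frozen: nn.Embedding,
                     batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: indices < n use `trainable`, indices >= n use `frozen`."""
    n = trainable.num_embeddings
    is_frozen = batch >= n  # True where the index refers to a frozen row
    # Stage 1: trainable lookup; frozen indices are clamped to a valid
    # placeholder row and overwritten below.
    out = trainable(batch.clamp(max=n - 1))
    # Stage 2: frozen lookup, shifting indices so the first frozen row is 0;
    # trainable indices are clamped to 0 and discarded by torch.where.
    frozen_out = frozen((batch - n).clamp(min=0))
    return torch.where(is_frozen.unsqueeze(-1), frozen_out, out)


torch.manual_seed(0)
trainable = nn.Embedding(200, 100)
frozen = nn.Embedding(1, 100)
frozen.weight.requires_grad_(False)
batch = torch.tensor([[3, 200], [3, 200], [3, 200], [188, 2], [188, 200]])

out = two_stage_lookup(trainable, frozen, batch)
# Same values as the torch.cat version, without concatenating the weights.
reference = F.embedding(batch, torch.cat([trainable.weight, frozen.weight]))
print(torch.equal(out, reference))  # True
```

torch.where sends zero gradient to the placeholder lookups, so the clamped row gets no spurious update. Whether this beats torch.cat in practice probably depends on how large the table is compared with the batch.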

Some unsolicited style advice:

I would probably write this as

```python
with torch.no_grad():
```
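The rest of that snippet is not shown. Assuming it refers to the weight initialization in the question, one common pattern is to copy the initial values in place inside torch.no_grad() instead of assigning to `.weight.data`. The tensors `init_rel_embs` and `freeze_init_rel_embs` below are random stand-ins for the question's config values.

```python
import torch
import torch.nn as nn

dim = 100
# Hypothetical initial values standing in for config.init_rel_embs and
# config.freeze_init_rel_embs from the question.
init_rel_embs = torch.randn(200, dim)
freeze_init_rel_embs = torch.randn(1, dim)

trainable_rel_embeddings = nn.Embedding(200, dim)
freeze_rel_embeddings = nn.Embedding(1, dim)
freeze_rel_embeddings.weight.requires_grad_(False)

with torch.no_grad():
    # In-place copies are not recorded by autograd here; without no_grad,
    # copy_ into a leaf tensor that requires grad raises a RuntimeError.
    trainable_rel_embeddings.weight.copy_(init_rel_embs)
    freeze_rel_embeddings.weight.copy_(freeze_init_rel_embs)

# A frozen layer can also be built from a tensor in one step:
frozen_alt = nn.Embedding.from_pretrained(freeze_init_rel_embs, freeze=True)
print(trainable_rel_embeddings.weight.requires_grad)  # True
print(frozen_alt.weight.requires_grad)                # False
```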

Best regards