Partially freeze embedding layer

I’m implementing a modification of the Seq2Seq model in PyTorch, where I want to partially freeze the embedding layer, i.e. freeze the first N rows and leave the rest unfrozen. What is the best strategy to do this?

The best strategy is to split your embedding layer into two embedding layers: one for the first N rows (frozen), and a second one containing the rest (trainable).
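
A minimal sketch of this idea, assuming the first N token ids are the ones to freeze (the PartiallyFrozenEmbedding name and the offset-indexing scheme are just illustrative, not something from this thread):

import torch
import torch.nn as nn

class PartiallyFrozenEmbedding(nn.Module):
    # Token ids 0 .. num_frozen-1 are looked up in a frozen table,
    # ids num_frozen .. vocab_size-1 in a trainable table.
    def __init__(self, vocab_size, embedding_dim, num_frozen):
        super().__init__()
        self.num_frozen = num_frozen
        self.frozen = nn.Embedding(num_frozen, embedding_dim)
        self.frozen.weight.requires_grad = False  # freeze the first N rows
        self.trainable = nn.Embedding(vocab_size - num_frozen, embedding_dim)

    def forward(self, ids):
        is_frozen = ids < self.num_frozen
        # Clamp the indices so both lookups are always in range,
        # then pick the right result per token.
        frozen_out = self.frozen(ids.clamp(max=self.num_frozen - 1))
        trainable_out = self.trainable((ids - self.num_frozen).clamp(min=0))
        return torch.where(is_frozen.unsqueeze(-1), frozen_out, trainable_out)

emb = PartiallyFrozenEmbedding(vocab_size=100, embedding_dim=8, num_frozen=10)
out = emb(torch.tensor([[1, 5, 42, 99]]))  # single call, output shape [1, 4, 8]

You would typically also copy your pretrained vectors into emb.frozen.weight (and optionally emb.trainable.weight) before training.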


But how do we concatenate them? The following doesn’t work and throws an error:

embedding = nn.Embedding(input_size, hidden_size)
new_embedding = nn.Embedding(additional_rows, hidden_size)
embedded = torch.cat((embedding, new_embedding), 0)(word_seq)

If we use two embedding layers, how can we use a single call to do embedding lookup?

Here is a short example of how to split an embedding into two parts (note that this splits along the feature dimension; freezing the first N rows would require splitting along the vocabulary dimension instead):

import torch
import torch.nn as nn
import random

vocab_size = 10
embedding_dim_1 = 2
embedding_dim_2 = 3

embedding_1 = nn.Embedding(vocab_size, embedding_dim_1)
embedding_2 = nn.Embedding(vocab_size, embedding_dim_2)

# Random vector of length 15 consisting of indices 0, ..., 9
x = torch.LongTensor([random.randint(0, 9) for _ in range(15)])
# Adding batch dimension
x = x[None, :]

emb_1 = embedding_1(x)
print(emb_1)
emb_2 = embedding_2(x)
print(emb_2)
# Concatenating embeddings along dimension 2
emb = torch.cat([emb_1, emb_2], dim=2)
print(emb)


Here is an example used in a bi-LSTM.

import numpy as np
import torch
import torch.nn as nn

embedding_size = 50
pretrain_word_embedding = ...  # a numpy matrix with the pretrained embeddings
vocab_size = pretrain_word_embedding.shape[0]
freeze_word_embs = nn.Embedding(vocab_size, embedding_size)
freeze_word_embs.weight.data.copy_(torch.from_numpy(pretrain_word_embedding))
freeze_word_embs.weight.requires_grad = False

# Randomly initialized, trainable embedding of the same shape
random_embs = np.empty([vocab_size, embedding_size])
scale = np.sqrt(3.0 / embedding_size)
for index in range(vocab_size):
    random_embs[index, :] = np.random.uniform(-scale, scale, [1, embedding_size])
unfreeze_word_embs = nn.Embedding(vocab_size, embedding_size)
unfreeze_word_embs.weight.data.copy_(torch.from_numpy(random_embs))

word_inputs = ...  # word id tensor with shape [batch_size, sentence_max_length]
freeze_boundary = 10  # if word_id < freeze_boundary: pick the frozen embedding
batch_size = word_inputs.size(0)
sent_len = word_inputs.size(1)
freeze_embs = freeze_word_embs(word_inputs)
unfreeze_embs = unfreeze_word_embs(word_inputs)
word_embs = []
for i, w_input in enumerate(word_inputs):
    freeze_emb = freeze_embs[i]
    unfreeze_emb = unfreeze_embs[i]
    word_emb = []
    for j, word_id in enumerate(w_input):
        if word_id < freeze_boundary:
            word_emb.append(freeze_emb[j])
        else:
            word_emb.append(unfreeze_emb[j])
    word_emb = torch.stack(word_emb)
    word_embs.append(word_emb)
word_embs = torch.stack(word_embs)
# word_embs is the final partially frozen embedding.
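
As a side note, the per-token Python loops above can be replaced by a single torch.where over the whole batch; a sketch reusing the variable names from the snippet above:

# Pick the frozen embedding where word_id < freeze_boundary, otherwise the trainable one.
mask = (word_inputs < freeze_boundary).unsqueeze(-1)       # [batch_size, sent_len, 1]
word_embs = torch.where(mask, freeze_embs, unfreeze_embs)  # [batch_size, sent_len, embedding_size]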

I’ve been struggling with this problem as well, and after searching around I figured the first approach in this answer might be the right way to “split into two embedding layers”.


Hope it helps.