Expanding pretrained embedding


In my current project I’m using Google’s pretrained word2vec embedding, GoogleNews-vectors-negative300.bin.
However, I was surprised that many words in my text are missing from the embedding’s vocabulary (e.g. xenophobia, submissive).

First, I would like to know how I can extend an nn.Embedding with new words. I guess I should then enable backpropagation only on that part of the embedding so that it can be learned.

Second, I don’t understand why, but I have to go through gensim to load the embedding. Indeed, this:

import torch
import torch.nn as nn
from torchtext import data, vocab

text_field = data.Field(sequential=True, tokenize=_tokenize_str)
dataset = data.TabularDataset(
    # path/format arguments omitted in the original post
    fields=[('id', None), ('content', text_field)])
text_field.build_vocab(dataset)
vectors = vocab.Vectors('/data/GoogleNews-vectors-negative300.bin.gz')
text_field.vocab.set_vectors(vectors.stoi, vectors.vectors, vectors.dim)
embedding = nn.Embedding.from_pretrained(torch.FloatTensor(text_field.vocab.vectors))

does not work. Instead, I first have to do:

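import gensim

model = gensim.models.KeyedVectors.load_word2vec_format('data/GoogleNews-vectors-negative300.bin.gz', binary=True)
# assumed intermediate step (not shown in the original post): re-save in word2vec *text* format
model.save_word2vec_format('/content/drive/My Drive/ActNews/data/myGoogleEmbedding.bin', binary=False)
vectors = vocab.Vectors('/content/drive/My Drive/ActNews/data/myGoogleEmbedding.bin')
text_field.vocab.set_vectors(vectors.stoi, vectors.vectors, vectors.dim)
embedding = nn.Embedding.from_pretrained(torch.FloatTensor(text_field.vocab.vectors))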

Best regards,
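Regarding the second question, a likely explanation is the file format. torchtext’s `vocab.Vectors` parses the plain-text vector format (one `word v1 v2 ... vd` line per entry; a `.gz` file is opened with gzip but still read line by line as text), whereas GoogleNews-vectors-negative300.bin.gz uses word2vec’s binary layout. gensim’s `load_word2vec_format(..., binary=True)` understands that layout, and re-saving with `save_word2vec_format(..., binary=False)` produces a file torchtext can read. The sketch below performs the same conversion without gensim, to show what the binary layout looks like. It assumes the layout written by the original C word2vec tool: a `"<vocab_size> <dim>\n"` header, then for each word its UTF-8 bytes, a space, and `dim` little-endian float32 values, usually followed by a newline. `word2vec_bin_to_text` is a hypothetical helper name.

```python
import gzip
import struct


def word2vec_bin_to_text(src_path, dst_path):
    """Hypothetical helper: rewrite a word2vec binary file (optionally .gz) in the
    text format 'word v1 v2 ... vd', preceded by a 'count dim' header line."""
    opener = gzip.open if src_path.endswith(".gz") else open
    with opener(src_path, "rb") as fin, open(dst_path, "w", encoding="utf-8") as fout:
        vocab_size, dim = map(int, fin.readline().split())
        fout.write(f"{vocab_size} {dim}\n")
        for _ in range(vocab_size):
            # The word is every byte up to the next space; newlines left over
            # from the previous entry are skipped.
            word = bytearray()
            while True:
                ch = fin.read(1)
                if not ch:
                    raise EOFError("file ended in the middle of an entry")
                if ch == b" ":
                    break
                if ch != b"\n":
                    word += ch
            # Assumed layout: `dim` little-endian float32 values follow the space.
            values = struct.unpack(f"<{dim}f", fin.read(4 * dim))
            fout.write(word.decode("utf-8") + " " + " ".join(map(repr, values)) + "\n")
```

Reading one byte at a time is slow for the 3M-word GoogleNews file, so gensim remains the practical tool; the sketch is only meant to show why a text-format parser cannot read the `.bin.gz` file directly.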


You could try to concatenate the pretrained weight matrix with a newly initialized tensor to create the new weight matrix with the extended vocabulary.
To keep the pretrained embedding matrix constant, you could register a hook to zero out the gradients of this part of the weight.
Here is a small code snippet to demonstrate this approach:

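import torch
import torch.nn as nn

vocab_size = 2  # size of the pretrained vocabulary
embedding_dim = 10
emb = nn.Embedding(vocab_size, embedding_dim)  # stands in for the pretrained embedding

# Add vocab: append 2 randomly initialized rows for the new words
emb.weight = nn.Parameter(
    torch.cat((emb.weight, torch.randn(2, embedding_dim))))
emb.num_embeddings = emb.weight.size(0)  # keep the module attribute consistent

# Register hook to zero out gradients of pretrained embedding weights
mask = torch.zeros_like(emb.weight)
mask[2:] = 1.
emb.weight.register_hook(lambda grad: grad*mask)

# Training
x = torch.randint(0, 4, (10,))
out = emb(x)
out.mean().backward()

# Should print zeros in the first half (the two pretrained rows)
print(emb.weight.grad)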
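Applied to the setting in the question, the same idea might look like the sketch below. It assumes a pretrained `(V, D)` tensor plus a word-to-row dict `stoi` (for example taken from torchtext’s `Vectors.vectors` and `Vectors.stoi`); `extend_embedding` is a hypothetical helper, and scaling the new random rows by the pretrained standard deviation is an assumed initialization choice, not something from the thread. The last lines take one SGD step to check that only the new rows move.

```python
import torch
import torch.nn as nn


def extend_embedding(pretrained, new_words, stoi):
    """Hypothetical helper: append trainable rows for `new_words` to a pretrained
    (V, D) weight matrix and mask gradients so only the new rows are updated.
    `stoi` (word -> row index) is updated in place."""
    num_old, dim = pretrained.shape
    # Skip words that are already known and drop duplicates, keeping order.
    new_words = [w for w in dict.fromkeys(new_words) if w not in stoi]
    for offset, word in enumerate(new_words):
        stoi[word] = num_old + offset
    # Assumed initialization: random rows on roughly the pretrained scale.
    new_rows = torch.randn(len(new_words), dim) * pretrained.std()
    emb = nn.Embedding.from_pretrained(torch.cat((pretrained, new_rows)), freeze=False)
    # Per-row mask of shape (N, 1), broadcast over the embedding dim:
    # 0 for pretrained rows, 1 for new rows.
    mask = torch.zeros(num_old + len(new_words), 1)
    mask[num_old:] = 1.0
    emb.weight.register_hook(lambda grad: grad * mask)
    return emb


# Toy stand-in for the pretrained table (GoogleNews has 3M rows of dim 300).
stoi = {"the": 0, "news": 1}
pretrained = torch.randn(2, 4)
emb = extend_embedding(pretrained, ["xenophobia", "submissive"], stoi)

opt = torch.optim.SGD(emb.parameters(), lr=0.1)  # plain SGD, no weight decay
before = emb.weight.detach().clone()
ids = torch.tensor([stoi["the"], stoi["xenophobia"], stoi["submissive"]])
emb(ids).pow(2).sum().backward()
opt.step()

print(torch.equal(emb.weight[:2], before[:2]))  # True: pretrained rows unchanged
print(torch.equal(emb.weight[2:], before[2:]))  # False: new rows were updated
```

The hook only zeroes the gradient. If the optimizer is configured with `weight_decay`, that term is added inside the optimizer step, after the hook has run, so the pretrained rows would still drift.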

Let me know if this works for you.
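A different technique from the gradient hook suggested above is to keep the pretrained table and the new-word table as two separate embeddings and route each index to one of them. The sketch assumes ids `0 .. num_pretrained - 1` are pretrained words and ids `>= num_pretrained` are new words; `ExtendedEmbedding` is a hypothetical name.

```python
import torch
import torch.nn as nn


class ExtendedEmbedding(nn.Module):
    """Hypothetical module: frozen pretrained table for ids < num_pretrained,
    trainable table for the new words (ids >= num_pretrained)."""

    def __init__(self, pretrained, num_new):
        super().__init__()
        self.num_pretrained = pretrained.size(0)
        self.pretrained = nn.Embedding.from_pretrained(pretrained, freeze=True)
        self.new = nn.Embedding(num_new, pretrained.size(1))

    def forward(self, ids):
        is_new = ids >= self.num_pretrained
        # Clamp so both lookups stay in range; torch.where keeps only the valid one.
        old_vecs = self.pretrained(ids.clamp(max=self.num_pretrained - 1))
        new_vecs = self.new((ids - self.num_pretrained).clamp(min=0))
        return torch.where(is_new.unsqueeze(-1), new_vecs, old_vecs)


layer = ExtendedEmbedding(torch.randn(3, 4), num_new=2)  # ids 0-2 pretrained, 3-4 new
out = layer(torch.tensor([[0, 4], [3, 2]]))
print(out.shape)  # torch.Size([2, 2, 4])
```

This costs two lookups per token, but the frozen weight has `requires_grad=False`, so its `.grad` stays `None` and the `torch.optim` optimizers skip it entirely, including any weight decay.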