In nn.Embedding
can we provide an option so that specific indices are never updated? Like a mask for example?
This seems like a very specific use case, and there is no clear definition of how such an option would work. It would probably be easiest just to run loss.backward()
and then zero out the particular gradients you don't want applied before calling optimizer.step().
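A minimal sketch of that suggestion inside a training step (embedding, loss, and optimizer are assumed to already exist; frozen_ids is a placeholder for the indices that should stay fixed):

loss.backward()
# zero the gradient rows of the embedding indices that should never be updated
frozen_ids = torch.tensor([0, 1, 2])   # placeholder indices
embedding.weight.grad[frozen_ids] = 0
optimizer.step()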
Yes, currently I am using a weight mask to invalidate the data at the specific indices per batch, kind of like this:
def invalidate_embeddings(self, max_ent_id):
    """
    Invalidate (zero out) the entity embeddings, i.e. all rows with id <= max_ent_id.
    :param max_ent_id: int
    :return: None
    """
    vocab_size = self.embedding.weight.data.size(0)
    ids = torch.arange(0, vocab_size)
    # mask is 1 for rows to keep and 0 for the entity block at the start
    mask = (ids > max_ent_id).unsqueeze(1).to(self.embedding.weight.dtype)
    self.embedding.weight.data *= mask
So far this works because I kept the block of entity indices at the beginning of the embedding matrix. It gets trickier if the indices are not contiguous: then we need loops and training becomes slower.
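If the frozen indices are available as a tensor rather than a contiguous block, one way to avoid a Python loop is to build the mask by indexing; a minimal sketch (frozen_ids is a placeholder):

vocab_size = self.embedding.weight.data.size(0)
mask = torch.ones(vocab_size, 1)
frozen_ids = torch.tensor([3, 17, 42])   # non-contiguous indices to invalidate (placeholder)
mask[frozen_ids] = 0                     # zero out the rows to invalidate
self.embedding.weight.data *= mask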
I usually just register a hook on the variable that needs masking:

def hook(grad):
    # mask: tensor of ones and zeros broadcastable to grad's shape;
    # zeros mark the gradient entries that should be dropped
    return grad * mask

var.register_hook(hook)   # var is the tensor to freeze, e.g. embedding.weight
This way, you can also ensure that unwanted gradients are not propagated during backprop.
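Applied to the embedding case above, a self-contained sketch of such a hook (sizes and frozen indices are placeholders):

import torch
import torch.nn as nn

embedding = nn.Embedding(100, 16)

# ones everywhere except the rows whose gradients we want to drop
grad_mask = torch.ones(100, 1)
grad_mask[torch.tensor([0, 1, 2])] = 0

# the hook multiplies the incoming gradient by the mask on every backward pass
embedding.weight.register_hook(lambda grad: grad * grad_mask)

Unlike zeroing gradients manually after each loss.backward(), the hook only has to be registered once.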