For nn.Embedding, can we have an option so that specific indices are never updated? Like a mask, for example?
This seems like a very specific use case, and there is no clear definition of how it would work. It would probably be easiest to run
loss.backward() and then zero out the particular gradients you don't want applied.
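For example, here is a minimal sketch of that approach; the `frozen_ids` tensor and the toy sizes below are assumptions for illustration, not part of the question:

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)
frozen_ids = torch.tensor([0, 3, 7])  # hypothetical rows that should never update

optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

inputs = torch.tensor([0, 1, 3, 5])
loss = embedding(inputs).sum()
loss.backward()

# Zero the gradient rows of the frozen indices before the update,
# so optimizer.step() leaves those embeddings unchanged.
embedding.weight.grad[frozen_ids] = 0
optimizer.step()
```

Note this assumes the default dense gradient; with `nn.Embedding(..., sparse=True)` the row indexing above would need to be adapted.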
Yes, currently I am using a weight mask to invalidate the embeddings at the specific indices after each batch. Something like this:
```python
def invalidate_embeddings(self, max_ent_id):
    """
    Invalidate the entity embeddings.

    :param max_ent_id: int
    :return:
    """
    vocab_size = self.embedding.weight.data.size(0)
    ids = torch.arange(0, vocab_size)
    # Cast the boolean comparison to float so it can multiply the weights.
    mask = (ids > max_ent_id).float()
    mask = mask.unsqueeze(1)
    weight = self.embedding.weight.data
    weight = weight * mask
    self.embedding.weight.data = weight
```
So far this works because I kept a block of indices at the beginning of the embedding matrix. It gets trickier if the indices are not contiguous: then we would need loops, and training would be slower.
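For what it's worth, the mask can also be built by indexing with a tensor of row ids, so non-contiguous indices should not need a Python loop. A sketch, where `frozen_ids` and the sizes are assumed for illustration:

```python
import torch

vocab_size, dim = 10, 4
weight = torch.randn(vocab_size, dim)
frozen_ids = torch.tensor([1, 4, 8])  # arbitrary, non-contiguous rows (assumed)

# Start from an all-ones column mask and zero only the frozen rows.
mask = torch.ones(vocab_size, 1)
mask[frozen_ids] = 0
weight = weight * mask  # broadcasts over the embedding dimension
```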
I usually just register a hook on the tensor that needs masking:
```python
def hook(grad):
    mask = matrix_of_ones_and_zeros  # 0/1 tensor, broadcastable to grad's shape
    return grad * mask

Var.register_hook(hook)
```
This way, you can also ensure that unwanted gradients are not propagated during backprop.
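Putting it together, here is a sketch of the hook approach applied to an nn.Embedding weight; the sizes and the reuse of `max_ent_id` from the earlier snippet are assumptions:

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)
max_ent_id = 2  # hypothetical boundary of the frozen block

ids = torch.arange(embedding.weight.size(0))
mask = (ids > max_ent_id).float().unsqueeze(1)

# The hook scales the incoming gradient by the mask, so rows
# 0..max_ent_id receive zero gradient on every backward pass.
embedding.weight.register_hook(lambda grad: grad * mask)

loss = embedding(torch.tensor([0, 1, 5])).sum()
loss.backward()
print(embedding.weight.grad[: max_ent_id + 1])  # all zeros
```

Since the hook runs inside autograd, no per-batch cleanup of the weights is needed, unlike the masking-after-update approach above.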