Restrict backpropagation for specific indices in nn.Embedding

Could nn.Embedding provide an option so that specific indices are never updated, for example via a mask?

This seems like a very specific use case, and there isn't a clear definition of how such an option would work. It would probably be easiest to just run loss.backward() and then zero out the particular gradients you don't want to apply.
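
A minimal sketch of that approach, assuming a dense gradient (the default for nn.Embedding) and using frozen_ids as a placeholder for the indices that should never be updated:

    import torch

    embedding = torch.nn.Embedding(10, 4)
    frozen_ids = torch.tensor([0, 3])  # placeholder: indices to keep fixed

    loss = embedding(torch.tensor([0, 1, 3, 5])).sum()
    loss.backward()

    # Zero the gradient rows of the frozen indices before optimizer.step()
    embedding.weight.grad[frozen_ids] = 0.0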


Yes, currently I am using a weight mask to invalidate the data at the specific indices on each batch, kind of like this:

    def invalidate_embeddings(self, max_ent_id):
        """
        Invalidate (zero out) the entity embeddings, i.e. every row whose
        index is <= max_ent_id.
        :param max_ent_id: int, largest entity index to invalidate
        :return:
        """
        weight = self.embedding.weight.data
        vocab_size = weight.size(0)
        ids = torch.arange(0, vocab_size, device=weight.device)
        # 1 for rows that keep their values, 0 for rows that are zeroed out
        mask = (ids > max_ent_id).unsqueeze(1).to(weight.dtype)
        self.embedding.weight.data = weight * mask

So far this works well because I kept the frozen indices as a contiguous block at the beginning of the embedding matrix. It gets a bit trickier if the indices are not contiguous; then we might need loops, and the training would be slower.
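
A sketch of how the mask could still be built without a Python loop when the indices are not contiguous (frozen_ids is a placeholder for the arbitrary set of indices to invalidate):

    import torch

    vocab_size, dim = 100, 16
    embedding = torch.nn.Embedding(vocab_size, dim)

    frozen_ids = torch.tensor([3, 7, 42])  # placeholder: non-contiguous indices
    mask = torch.ones(vocab_size, 1)
    mask[frozen_ids] = 0.0                 # zero only the selected rows

    with torch.no_grad():
        embedding.weight.mul_(mask)        # zero the data of those rows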

I usually just register a hook on the tensor that needs masking:

    def hook(grad):
        # mask has the same shape as grad: 1 where the gradient should flow,
        # 0 where it should be blocked
        mask = matrix_of_ones_and_zeros
        return grad * mask

    var.register_hook(hook)

This way, you can also ensure that unwanted gradients are not propagated during backprop.
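
For example, applied to the embedding weight directly (a sketch, with frozen_ids as a placeholder for the indices that should stay fixed):

    import torch

    embedding = torch.nn.Embedding(10, 4)
    frozen_ids = torch.tensor([0, 3])  # placeholder: indices to keep fixed

    mask = torch.ones_like(embedding.weight)
    mask[frozen_ids] = 0.0  # block gradients for the frozen rows

    embedding.weight.register_hook(lambda grad: grad * mask)

    loss = embedding(torch.tensor([0, 1, 3])).sum()
    loss.backward()
    # embedding.weight.grad now has all-zero rows at indices 0 and 3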
