Restrict backpropagation in specific indices in nn.Embedding

Can nn.Embedding provide an option so that specific indices are never updated, like a mask for example?

This seems like a very specific use case and there is not a clear definition of how this would work. It would probably be easiest just to run loss.backward() and then zero out the particular gradients you don’t want to consider.
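A minimal sketch of that suggestion, assuming a toy embedding where rows 0 and 2 should stay frozen (the sizes and indices here are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical setup: a 5x3 embedding where rows 0 and 2 must not be updated.
emb = nn.Embedding(5, 3)
frozen = torch.tensor([0, 2])

loss = emb(torch.tensor([0, 1, 2, 3])).sum()
loss.backward()

# Zero out the gradients of the frozen rows before the optimizer step.
emb.weight.grad[frozen] = 0.0
```

After this, an optimizer step leaves the frozen rows unchanged (for plain SGD; optimizers with momentum or weight decay may still move them).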


Yes, currently I am using a weight mask to invalidate the data of the specific indices each batch. Something like this:

    def invalidate_embeddings(self, max_ent_id):
        """Invalidate (zero out) the entity embedding rows.
        :param max_ent_id: int, largest entity index to invalidate
        """
        # `self.embedding` is assumed to be the nn.Embedding module
        vocab_size = self.embedding.num_embeddings
        ids = torch.arange(0, vocab_size)
        mask = (ids > max_ent_id).unsqueeze(1)  # shape (vocab_size, 1)
        weight = self.embedding.weight.data
        weight.mul_(mask.to(weight.dtype))  # zero rows 0..max_ent_id in place

So far this works because I kept a block of indices at the beginning of the embedding matrix. It gets trickier if the indices are not contiguous: then we need loops and training will be slower.
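For what it's worth, a mask over non-contiguous indices can also be built without a Python loop, using boolean index assignment (the sizes and index values below are hypothetical):

```python
import torch

vocab_size = 10
# Hypothetical non-contiguous indices whose rows should be invalidated.
frozen_ids = torch.tensor([1, 4, 7])

# Start from all-ones, then switch off the frozen rows in one indexed write.
keep = torch.ones(vocab_size, dtype=torch.bool)
keep[frozen_ids] = False
mask = keep.unsqueeze(1)  # shape (vocab_size, 1), broadcasts over the embedding dim
```

The resulting `mask` can be multiplied into the weight (or gradient) exactly like the contiguous version above.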

I usually just register a hook on the tensor that needs masking:

        def hook(grad):
            # mask: a tensor of ones and zeros, broadcastable to grad's shape
            mask = matrix_of_ones_and_zeros
            return grad * mask


This way, you can also ensure that unwanted gradients are not propagated during backprop.
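Putting the hook approach together, a runnable sketch (again with made-up sizes and frozen indices) might look like this:

```python
import torch
import torch.nn as nn

# Hypothetical 5x3 embedding; gradients for rows 0 and 2 should be blocked.
emb = nn.Embedding(5, 3)
mask = torch.ones(5, 1)
mask[torch.tensor([0, 2])] = 0.0  # zeros for rows whose gradients we drop

# The hook runs during backward and replaces the gradient with the masked one.
emb.weight.register_hook(lambda grad: grad * mask)

loss = emb(torch.tensor([0, 1, 2, 3])).sum()
loss.backward()
```

Unlike zeroing `grad` after `backward()`, the hook fires automatically on every backward pass, so there is nothing extra to remember in the training loop.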