Backwards through embedding?

Hi there!

For some reasons I need to compute the gradient of the loss with respect to the input data.

My problem is that my model starts with an embedding layer, which doesn’t support propagating the gradient through it. Indeed, to set requires_grad=True on my input data, it has to be of type float. But the embedding module (nn.Embedding) only accepts inputs of type long.

Is there anything I am missing, or does the embedding layer definitely stop the backpropagation? My idea to make it work is to replace the embedding layer, which performs a lookup, with a matrix multiplication. But I first want to be sure my understanding is correct.
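
To make the idea concrete, here is the kind of equivalence I have in mind (just a sketch, the names are mine): an embedding lookup gives the same values as multiplying a one-hot encoding of the indices with the embedding weight matrix, and since the one-hot tensor is a float tensor, it can require gradients.

import torch
import torch.nn as nn

embed = nn.Embedding(5, 10)
indices = torch.tensor([0, 1, 4])              # the usual long-typed lookup input

one_hot = torch.eye(5)[indices]                # float one-hot encoding of the indices
one_hot.requires_grad_(True)                   # works, because it is a float tensor
out_mm = one_hot @ embed.weight                # same values as embed(indices)

print(torch.allclose(out_mm, embed(indices)))  # True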

Here is working dummy code for the situation; my aim is to have requires_grad set to True on data:

import torch
import torch.nn as nn

class Seq(nn.Module):
    def __init__(self):
        super(Seq, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 3),
            nn.Sigmoid()
        )
        self.embed = nn.Embedding(5, 10)

    def forward(self, data):
        return self.model(self.embed(data))

model = Seq()

#####
# I want the `requires_grad` to be True!
#####
data = torch.tensor(
    [[0, 1, 4, 3, 1], [1, 0, 4, 3, 0], [4, 2, 3, 1, 4]],
    requires_grad=False, dtype=torch.long)

target = torch.rand([3, 1, 3])
output = model(data)

loss = torch.sum(torch.sqrt((output-target)**2))
loss.backward(retain_graph=True)

I’m not sure that’s possible with an Embedding layer since, as you’ve already explained, you have to pass an index tensor containing long values.
If there were a method to compute gradients for long types (which I’m unaware of), it would mean your indices change due to the gradients.
Is this assumption correct?
For an NLP use case, the gradient would change the “words” that were encoded using the indices. Is there any underlying relation between your indices, i.e. are indices 2 and 3 closer to each other than 2 and 100?
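
To illustrate why (rough sketch; the exact error messages may differ between versions):

import torch
import torch.nn as nn

idx = torch.tensor([0, 1, 4], dtype=torch.long)
try:
    idx.requires_grad_(True)   # index (long) tensors cannot require gradients
except RuntimeError as e:
    print(e)                   # message along the lines of: only floating point tensors can require gradients

embed = nn.Embedding(5, 10)
try:
    embed(idx.float())         # and nn.Embedding rejects float inputs
except RuntimeError as e:
    print(e)                   # message along the lines of: indices are expected to have scalar type Long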

Thanks for your response!

No, there is no relation between the words and their indices.
The idea behind computing the gradient w.r.t. the input data is to get some indication of the impact of the data on the model, so that data selection can be performed during training.
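
To make the intent concrete, here is a rough sketch of the kind of score I am after (all the names below are placeholders, and the loss is just a stand-in): with the one-hot workaround, the per-token gradient norm can serve as an “impact” score.

import torch

vocab, emsize = 5, 10
weight = torch.randn(vocab, emsize, requires_grad=True)  # stands in for the embedding matrix
indices = torch.tensor([[0, 1, 4, 3, 1]])

one_hot = torch.eye(vocab)[indices]     # (1, 5, vocab), float
one_hot.requires_grad_(True)
embedded = one_hot @ weight             # (1, 5, emsize)

loss = embedded.sum()                   # placeholder for the real loss
loss.backward()
impact = one_hot.grad.norm(dim=-1)      # one score per input token
print(impact)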

I have the same concern with the same problem. Did you find a fix?

I implemented an embedding module using matrix multiplication instead of lookup.

Here is my class; you may need to adapt it. I had some memory concerns when backpropagating the gradient, so you can activate it or not using self.requires_grad.

import torch
import torch.nn as nn

"""Implements the EmbeddingMul class
Author: Noémien Kocher
Date: Fall 2018
Unit test: embedding_mul_test.py
"""

logger = None


# A PyTorch module cannot have a logger as its attribute, because
# it then cannot be serialized.
def set_logger(alogger):
    global logger
    logger = alogger


class EmbeddingMul(nn.Module):
    """This class implements a custom embedding mudule which uses matrix
    multiplication instead of a lookup. The method works in the functional
    way.
    Note: this class accepts the arguments from the original pytorch module
    but only with values that have no effects, i.e set to False, None or -1.
    """

    def __init__(self, depth, device):
        super(EmbeddingMul, self).__init__()
        # i.e. the dictionary size
        self.depth = depth
        self.device = device
        self.ones = torch.eye(depth, requires_grad=False, device=self.device)
        self._requires_grad = False
        # "oh" means One Hot
        self.last_oh = None
        self.last_weight = None

    @property
    def requires_grad(self):
        return self._requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        self._requires_grad = value
        logger.info(
            f"(embedding mul) requires_grad set to {self.requires_grad}. ")

    def forward(self, input, weight, padding_idx=None, max_norm=None,
                norm_type=2., scale_grad_by_freq=False, sparse=False):
        """Declares the same arguments as the original pytorch implementation
        but only for backward compatibility. Their values must be set to have
        no effects.
        Args:
            - input: of shape (bptt, bsize)
            - weight: of shape (dict_size, emsize)
        Returns:
            - result: of shape (bptt, bsize, dict_size)
        """
        # ____________________________________________________________________
        # Checks if unsupported argument are used
        if padding_idx not in (None, -1):
            raise NotImplementedError(
                f"padding_idx must be None or -1, not {padding_idx}")
        if max_norm is not None:
            raise NotImplementedError(f"max_norm must be None, not {max_norm}")
        if scale_grad_by_freq:
            raise NotImplementedError(f"scale_grad_by_freq must be False, "
                                      f"not {scale_grad_by_freq}")
        if sparse:
            raise NotImplementedError(f"sparse must be False, not {sparse}")
        # ____________________________________________________________________

        if self.last_oh is not None:
            del self.last_oh
        self.last_oh = self.to_one_hot(input)

        with torch.set_grad_enabled(self.requires_grad):
            result = torch.stack(
                [torch.mm(batch.float(), weight)
                 for batch in self.last_oh], dim=0)
        self.last_weight = weight.clone()
        return result

    def to_one_hot(self, input):
        # Returns a new tensor that doesn't share memory
        result = torch.index_select(
            self.ones, 0, input.view(-1).long()).view(
            input.size()+(self.depth,))
        result.requires_grad = self.requires_grad
        return result

    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)


if __name__ == "__main__":
    input = torch.tensor([[1, 2, 0], [3, 4, 5]])
    dim = 10
    mod = EmbeddingMul(dim, device="cpu")
    emmatrix = torch.rand(10, 5)
    print(emmatrix)
    output = mod(input, emmatrix, -1)
    print(output)
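
For completeness, here is roughly how it can be used to get a gradient w.r.t. the one-hot encoded input; the weight tensor, the data and the loss below are placeholders, not part of the class:

import logging
import torch

set_logger(logging.getLogger(__name__))          # the requires_grad setter logs, so a logger must be set
embed = EmbeddingMul(5, device="cpu")
embed.requires_grad = True                       # enable gradients through the one-hot input

weight = torch.randn(5, 10, requires_grad=True)  # stands in for a trained embedding matrix
data = torch.tensor([[0, 1, 4, 3, 1]])

out = embed(data, weight, padding_idx=-1)        # (1, 5, 10), differentiable
out.sum().backward()                             # placeholder loss
print(embed.last_oh.grad.shape)                  # gradient w.r.t. the one-hot input, (1, 5, 5)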

Thanks a million, very helpful!

Hey, I have a doubt: why did you add a requires_grad condition for the one-hot encoding vector? I think it should always be False.

I came across the same problem while trying to implement something where I wanted to use the embedding as part of the loss function. I ended up implementing a new Embedding class that works by (softly) one-hot encoding the inputs, with a concept similar to a resonance amplitude calculation, and then performing a matrix multiplication. My embedding looks like this:

class HotEmbedding(torch.nn.Module):
    def __init__(self, max_val, embedding_dim, eps=1e-2):
        super(HotEmbedding, self).__init__()
        # Index values 0..max_val-1 and a random embedding matrix
        # (wrapping self.B in nn.Parameter would additionally register it for training).
        self.A = torch.arange(max_val, requires_grad=False)
        self.B = torch.randn((max_val, embedding_dim), requires_grad=True)
        self.eps = eps

    def forward(self, x):
        # 1/((x^2 - A^2) + eps) peaks at the matching index, giving a soft,
        # differentiable one-hot encoding, which is then multiplied by B.
        return 1/((x.unsqueeze(1)**2 - self.A**2) + self.eps) @ self.B

And you can use it as:

layer = HotEmbedding(10, 5)
x = torch.tensor([1.,2.,3.,1.,2.,3.], requires_grad=True)
y = layer(x)

x and y look like this:

tensor([1., 2., 3., 1., 2., 3.], requires_grad=True)
tensor([[  16.7347, -117.7052,   10.8307,   14.4596,  -14.4507],
        [  55.2739,    8.8376,  -12.5775, -121.5673, -114.9303],
        [ -86.5265,  -92.7613,   84.4617,   68.5347,  203.2270],
        [  16.7347, -117.7052,   10.8307,   14.4596,  -14.4507],
        [  55.2739,    8.8376,  -12.5775, -121.5673, -114.9303],
        [ -86.5265,  -92.7613,   84.4617,   68.5347,  203.2270]],
       grad_fn=<MmBackward>)
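
And here is a quick check that the gradient indeed reaches the float-valued input (the scalar loss below is just a stand-in):

import torch

layer = HotEmbedding(10, 5)
x = torch.tensor([1., 2., 3., 1., 2., 3.], requires_grad=True)
y = layer(x)

y.sum().backward()   # any scalar stands in for a real loss
print(x.grad)        # non-None: gradients flow back to the input values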