Vectorizing an Entity Memory in Entities As Experts

Hello there! This is my first post here, so apologies if I’m not following the prescribed etiquette.

Here is my problem: I am trying to implement the Entity Memory from the Entities as Experts paper in order to augment a BERT model with supervised mention detection. In particular, let us focus on Section 2.1, which describes the Entity Memory architecture.

The model roughly works this way (we assume BERT-base, with a token embedding size of 768, a maximum sequence length of 512 and 12 attention layers):

  1. Create a BIO head and a link-predictor head after the first $l_0$ layers. Assume that after this step we have determined that an input sentence $w$ (a sequence of padded/truncated tokens) has $k$ mentions $[m_1, m_2, \ldots, m_k]$, where a mention is a triplet $(e_{m_i}, s_{m_i}, t_{m_i})$, i.e. an entity id and the start and the end of that span. Let $x$ be the output of the first transformer block, i.e. of the first $l_0$ layers.

  2. Let us build a “pseudo” entity embedding this way: for a mention $m_i$, let $h_{m_i} = W_f \, [x_{s_{m_i}} \| x_{t_{m_i}}]$, obtained by concatenating the first and the last element of the attention output over the span. $W_f$ is a fully connected layer, nothing special. One may think of $W_f$ as a query matrix and of the Entity Memory as a dot-product attention mechanism.

  3. For each such mention, let us compute the “real” entity embedding of every entity known in the dictionary. Such an entity embedding is computed as the matrix product $E \cdot e'_i$ (with $e'_i$ being a one-hot-encoded version of $e_i$, which is just a scalar id). Let us then take the softmax of the dot products between $h_{m_i}$ and all the entity embeddings.

Then we compute a weighted average of the entity embeddings across the entity dictionary, weighting each of them by its softmax score against the given pseudo embedding $h_{m_i}$.

  4. The loss function is the cross-entropy of this softmax distribution (the alphas) against the gold entity of each mention, accumulated over the mentions. The output instead is calculated by taking this weighted average and sending it through $W_b$ (a “key” matrix, if you think of the model as an attention mechanism). The whole computation is summarized right after this list.
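
In formulas, my reading of the above for a single mention $m_i$ (with $E \in \mathbb{R}^{d_{ent} \times N}$ holding one entity embedding per column, and writing $u_{m_i}$ for the weighted average):

$$
h_{m_i} = W_f\,[x_{s_{m_i}} \,\|\, x_{t_{m_i}}], \qquad
\alpha = \mathrm{softmax}(E^\top h_{m_i}), \qquad
u_{m_i} = E\,\alpha, \qquad
y_{s_{m_i}} = W_b\,u_{m_i}
$$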

I will show the code that illustrates the steps above:

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn import Linear, Module

DEVICE = "cuda"

class EntityMemory(Module):
    """
    Entity Memory, as described in the paper
    """

    def __init__(self, embedding_size: int, entity_size: int,
                 entity_embedding_size: int):
        """
        :param embedding_size the size of a token embedding. In the EaE paper it is called d_emb
            (elsewhere d_k, i.e. attention_heads * embedding_per_head)
        :param entity_size also known as N in the EaE paper, the maximum number of entities we store
        :param entity_embedding_size also known as d_ent in the EaE paper, the embedding of each entity
        
        """
        super().__init__()
        # pylint:disable=invalid-name
        self.N = entity_size
        self.d_emb = embedding_size
        self.d_ent = entity_embedding_size
        # pylint:disable=invalid-name
        self.W_f = Linear(2*embedding_size, self.d_ent)
        # pylint:disable=invalid-name
        self.W_b = Linear(self.d_ent, embedding_size)
        # pylint:disable=invalid-name
        self.E = Linear(self.N, self.d_ent)
        # TODO: Do not make these hardcoded.
        # The BIO class used to hold these but it got deprecated...
        self.begin = 1
        self.inner = 2
        self.out = 0

    def forward(
        self,
        X,
        bio_output: Optional[torch.LongTensor],
        entities_output: Optional[torch.LongTensor],
        k=100
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        :param X the (raw) output of the first transformer block. It has shape
                B x N x embed_size.
        :param bio_output the output of the BIO classifier. If not provided no loss is returned
                (a loss is required during training).
        :param entities_output the detected entities. If not provided no loss is returned
                (a loss is required during training).

        :returns a pair (loss, transformer_output). If either of entities_output or bio_output is
                  None, loss will be None as well.
        """

        calculate_loss = bio_output is not None and entities_output is not None

        begin_positions = torch.nonzero(bio_output == self.begin)

        y = torch.zeros_like(X).to(DEVICE)

        if calculate_loss:
            loss = torch.zeros((1,)).to(DEVICE)
        else:
            loss = None

        for pos in begin_positions:
            # Scan forward over the "inner" tags to find the last token of the mention
            end_mention = pos[1] + 1
            while end_mention < bio_output.shape[1] and bio_output[pos[0], end_mention] == self.inner:
                end_mention += 1
            end_mention -= 1

            first = X[pos[0], pos[1]]
            second = X[pos[0], end_mention]

            mention_span = torch.cat([first, second], 0).to(DEVICE)
            pseudo_entity_embedding = self.W_f(mention_span)  # d_ent

            # During training consider the whole entity dictionary
            if self.training and calculate_loss:
                alpha = F.softmax(self.E.weight.T.matmul(
                    pseudo_entity_embedding), dim=0)
                picked_entity = self.E.weight.matmul(alpha) + self.E.bias

            else:
                # K nearest neighbours
                topk = torch.topk(self.E.weight.T.matmul(
                    pseudo_entity_embedding), k)

                alpha = F.softmax(topk.values, dim=0)

                picked_entity = self.E.weight[:, topk.indices].matmul(alpha) + self.E.bias
            

            y[pos[0], pos[1]] = self.W_b(picked_entity)

            if calculate_loss:
                # Cross-entropy: accumulate the negative log of the alpha of the gold entity
                loss -= torch.log(alpha[entities_output[pos[0], pos[1]]])
        
        return loss, y
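
To make the shapes concrete, a rough smoke test for this module could look something like this (dummy values; 256 entity dimensions and 1000 entities are arbitrary choices here):

memory = EntityMemory(embedding_size=768, entity_size=1000,
                      entity_embedding_size=256).to(DEVICE)

batch, seq_len = 2, 16
X = torch.randn(batch, seq_len, 768, device=DEVICE)
bio = torch.zeros(batch, seq_len, dtype=torch.long, device=DEVICE)
bio[0, 3], bio[0, 4] = 1, 2              # a two-token mention in the first row
bio[1, 7] = 1                            # a single-token mention in the second row
entities = torch.zeros(batch, seq_len, dtype=torch.long, device=DEVICE)
entities[0, 3], entities[1, 7] = 42, 7   # gold entity ids at the mention starts

loss, y = memory(X, bio, entities)
print(loss.item(), y.shape)              # y has the same shape as X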

Here are the issues:

  1. We are using a for loop to find the end of each mention span for each input row in a batch, which is hardly vectorizable. Is there a way to vectorize it?
  2. Is the loss iteratively computed like this because we need a different alpha for each mention and its corresponding pseudo embedding? I assumed we could use a simple NLLLoss, but the problem is the alpha calculation. Clearly, one could build a list of all the different alphas and torch.stack() them, but how does that interact with autograd? Does iteratively accumulating a value work with autograd? (A sketch of the stacking idea follows this list.)
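
For reference, the stacking idea from point 2 would look roughly like this (a standalone sketch with made-up numbers; the variable names are hypothetical):

import torch
import torch.nn.functional as F

# M mentions, N entities (made-up numbers)
M, N = 3, 1000
logits = [torch.randn(N, requires_grad=True) for _ in range(M)]
alpha_list = [F.softmax(scores, dim=0) for scores in logits]   # one alpha per mention
gold_ids = torch.tensor([42, 7, 0])                            # gold entity id per mention

alphas = torch.stack(alpha_list)                               # (M, N); stack is differentiable
loss = F.nll_loss(torch.log(alphas + 1e-9), gold_ids)          # a single NLL over all mentions
loss.backward()                                                # gradients flow back through the stack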

I’d like to make it clear that this forward method works and trains, but at the very slow rate of roughly 1 batch per second (batch size 8 on a 2080 Ti), which makes it difficult to fine-tune on an appreciable fraction of Wikipedia.

To clarify, the reasons why I am performing those matmuls are:

  1. to get every row of the E matrix dot-product’d with the pseudo entity embedding
  2. to compute the right weighted average of each “real” entity embedding (aka a row of the E matrix) with the scores computed by alpha. This is even more apparent in the case where I compute the k nearest neighbours at evaluation time.

I realise that using raw matmuls against the weights of linear layers is an antipattern and suggestions are welcome (see the sketch right below).
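
For instance, the dot products and the weighted average over the whole dictionary can also be expressed with einsum and the linear layer itself, which is roughly the cleanup I have in mind (a sketch with made-up sizes, not necessarily faster):

import torch
import torch.nn.functional as F
from torch.nn import Linear

d_ent, N, M = 256, 1000, 4
E = Linear(N, d_ent)                              # E.weight has shape (d_ent, N)
h = torch.randn(M, d_ent)                         # pseudo entity embeddings, one per mention

scores = torch.einsum("md,dn->mn", h, E.weight)   # dot product of h with every entity embedding
alpha = F.softmax(scores, dim=1)                  # (M, N)
picked = E(alpha)                                 # weighted average (plus E.bias), shape (M, d_ent)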

PS: the idea of reading the entity memory as an attention mechanism is better described in Facts as Experts by Verga et al.

PS: Is MathJax enabled?

I partially solved the issues this way:

For the first problem I could still not find a truly satisfying solution, but I managed to remove the second for loop and convert the rest into a series of vectorized operations:

  • avoid computing the gradients for the bio input
  • compute the mention spans
  • use torch.cat() to combine the nonzero() result with the right end
  • compute the gradient on the mentions

The second issue became a bit easier to solve because we already have the mentions in matrix form. However, some shape calculations were a bit trickier and I had to make use of torch.bmm for the implementation of top-k during inference.

    def forward(
        self,
        X,
        bio_output: Optional[torch.LongTensor],
        entities_output: Optional[torch.LongTensor],
        k=100
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        :param X the (raw) output of the first transformer block. It has shape
                B x N x embed_size.
        :param bio_output the output of the BIO classifier. If not provided no loss is returned
                (a loss is required during training).
        :param entities_output the detected entities. If not provided no loss is returned
                (a loss is required during training).

        :returns a pair (loss, transformer_output). If either of entities_output or bio_output is
                  None, loss will be None as well.
        """


        y = torch.zeros_like(X).to(DEVICE)

        # Disable gradient calculation for BIO outputs, but re-enable them
        # for the span
        with torch.no_grad():

            calculate_loss = bio_output is not None and entities_output is not None
            if calculate_loss:
                loss = torch.zeros((1,)).to(DEVICE)
            else:
                loss = None

            begin_positions = torch.nonzero(bio_output == self.begin)

            # if no mentions are detected skip the entity memory.
            # TODO: would be nice to assess how often this happens.
            if len(begin_positions) == 0:
                return loss, y

            # FIXME: Not really parallelized (we don't have vmap yet...)
            end_positions = torch.tensor([
                self._get_last_mention(bio_output, pos) for pos in begin_positions]).unsqueeze(1).to(DEVICE)

        # Create an array with three rows:
        # [ batch_idx1, batch_idx2, batch_idx3... ]
        # [ start_idx1, start_idx2, start_idx3... ]
        # [ end_idx1,   end_idx2,   end_idx3...   ]

        positions = torch.cat([begin_positions, end_positions], 1).T

        first = X[positions[0], positions[1]]
        second = X[positions[0], positions[2]]

        mention_span = torch.cat([first, second], 1).to(DEVICE)
        pseudo_entity_embedding = self.W_f(mention_span) # num_of_mentions x d_ent

        # During training consider the whole entity dictionary
        if self.training and calculate_loss:
            alpha = F.softmax(
                pseudo_entity_embedding.matmul(self.E.weight), dim=1)

            # shape: M x d_ent, where M is the number of mentions
            picked_entity = self.E(alpha)
        
        else:
            # K nearest neighbours: score every entity against every mention
            topk = torch.topk(
                pseudo_entity_embedding.matmul(self.E.weight), k, dim=1)

            alpha = F.softmax(topk.values, dim=1)

            # mat1 has size (M x d_ent x k), mat2 has size (M x k x 1)
            # the result has size (M x d_ent x 1). Squeeze that out and we've got our
            # entities of size (M x d_ent)
            picked_entity = torch.bmm(self.E.weight[:, topk.indices].swapaxes(0, 1),
                                      alpha.view((-1, k, 1))).squeeze(-1) + self.E.bias

        y[positions[0], positions[1]] = self.W_b(picked_entity)

        # Compared to the original paper we use NLLLoss
        # (self.loss is assumed to be an NLLLoss instance added in __init__, not shown above).
        # Gradient-wise this should not change anything.
        if calculate_loss:
            loss = self.loss(alpha, entities_output[positions[0], positions[1]])
        else:
            loss = None
        return loss, y
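
For completeness, _get_last_mention is not shown above; it is essentially the inner while loop of the first version factored out into a helper, roughly:

    def _get_last_mention(self, bio_output, pos):
        """Return the index of the last token of the mention that starts at pos."""
        end_mention = pos[1] + 1
        while end_mention < bio_output.shape[1] and bio_output[pos[0], end_mention] == self.inner:
            end_mention += 1
        return end_mention - 1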

Hi,
did you manage to implement the code for this paper? I would appreciate it if you could share it. Thanks!

Hello. Yes, here it is. The model code is under src/models/entities_as_experts.py. The implementation is also listed on PapersWithCode.

Note: the latest PyTorch releases should have introduced vmap, but I still have not played with it, and the project is “done” on my end.
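
For anyone curious, torch.vmap roughly lets you write a per-sample function once and have it mapped over a leading batch dimension. A toy example (not tied to this model, and written from the docs rather than from experience):

import torch

def span_embedding(first, last):
    # per-sample: concatenate the first and last token vectors of a span
    return torch.cat([first, last], dim=0)

batched = torch.vmap(span_embedding)    # maps over dim 0 of both inputs
firsts = torch.randn(8, 768)
lasts = torch.randn(8, 768)
print(batched(firsts, lasts).shape)     # torch.Size([8, 1536])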