Hello there! This is my first post here, so apologies if I’m not following the prescribed etiquette.
Here is my problem: I am trying to implement the Entity Memory from the Entities as Experts paper in order to augment a BERT model with supervised mention detection. In particular, let us focus on paragraph 2.1, which describes the Entity Memory architecture.
The model roughly works this way (assume BERT-base: token embedding size 768, max sequence length 512, and 12 attention layers):
- Create BIO and link predictor heads after the first $l_0$ layers. Assume that after this step we have determined that an input sentence $w$ (a sequence of padded/truncated tokens) has $k$ mentions $[m_1, m_2, \ldots, m_k]$, where a mention is a triplet $(e_{m_i}, s_{m_i}, t_{m_i})$, i.e. a mention id plus the start and the end of that span. Suppose that $x$ is the output of the first transformer block, made of the first $l_0$ layers.
- Build a “pseudo” entity embedding this way: for a mention $m_i$, let $h_{m_i} = W_f [x_{s_{m_i}} \| x_{t_{m_i}}]$, obtained by concatenating the first and the last element of the span in the transformer output. $W_f$ is a fully connected layer, nothing special. One may think of $W_f$ as a query matrix and of the Entity Memory as a dot-product attention mechanism.
- For each such mention, compute the “real” entity embedding for all the entities known in the dictionary. Such an entity embedding is computed as a matrix product $E \cdot e'_i$ (with $e'_i$ being a one-hot-encoded version of $e_i$, which is a scalar id). Take the softmax of all the possible dot products between $h_{m_i}$ and the entity embeddings, then compute a weighted average of the entity embeddings across the entity dictionary using those softmax scores $\alpha$.
- The loss function is the cross-entropy, i.e. the negative log of the $\alpha$ score assigned to the ground-truth entity of each mention. The output instead is calculated by taking this weighted average and sending it through $W_b$ (a “key” matrix if you think of the model as an attention mechanism).
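To make the steps above concrete, here is a minimal shape-only sketch of the math (random weights and toy sizes, not the real model; a single mention, no batching):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_emb, d_ent, N = 768, 256, 1000   # token embedding, entity embedding, dictionary size

W_f = torch.randn(d_ent, 2 * d_emb)  # projects [x_start || x_end] into entity space
E = torch.randn(d_ent, N)            # one column per "real" entity embedding
W_b = torch.randn(d_emb, d_ent)      # projects the picked entity back to token space

x_start, x_end = torch.randn(d_emb), torch.randn(d_emb)
h = W_f @ torch.cat([x_start, x_end])   # pseudo entity embedding, shape (d_ent,)
alpha = F.softmax(E.T @ h, dim=0)       # one score per dictionary entry, sums to 1
picked = E @ alpha                      # weighted average of entity embeddings
y = W_b @ picked                        # shape (d_emb,), added back to the token stream
```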
I will show the code that illustrates the steps above:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Module
from typing import Optional, Tuple

DEVICE = "cuda"


class EntityMemory(Module):
    """
    Entity Memory, as described in the paper.
    """

    def __init__(self, embedding_size: int, entity_size: int,
                 entity_embedding_size: int):
        """
        :param embedding_size the size of a token embedding. In the EaE paper it
            is called d_emb, previously d_k (attention_heads * embedding_per_head)
        :param entity_size also known as N in the EaE paper, the maximum number
            of entities we store
        :param entity_embedding_size also known as d_ent in the EaE paper, the
            size of each entity embedding
        """
        super().__init__()
        # pylint:disable=invalid-name
        self.N = entity_size
        self.d_emb = embedding_size
        self.d_ent = entity_embedding_size
        # pylint:disable=invalid-name
        self.W_f = Linear(2 * embedding_size, self.d_ent)
        # pylint:disable=invalid-name
        self.W_b = Linear(self.d_ent, embedding_size)
        # pylint:disable=invalid-name
        self.E = Linear(self.N, self.d_ent)
        # TODO: Do not make these hardcoded.
        # The BIO class used to hold these but it got deprecated...
        self.begin = 1
        self.inner = 2
        self.out = 0

    def forward(
        self,
        X,
        bio_output: Optional[torch.LongTensor],
        entities_output: Optional[torch.LongTensor],
        k: int = 100
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        :param X the (raw) output of the first transformer block. It has shape
            B x N x embedding_size.
        :param bio_output the predicted BIO tags. If not provided no loss is
            returned (a loss is required during training).
        :param entities_output the detected entities. If not provided no loss is
            returned (a loss is required during training).
        :param k how many nearest neighbours to keep at evaluation time.
        :returns a pair (loss, transformer_output). If either of entities_output
            or bio_output is None, loss will be None as well.
        """
        calculate_loss = bio_output is not None and entities_output is not None
        begin_positions = torch.nonzero(bio_output == self.begin)
        y = torch.zeros_like(X).to(DEVICE)
        loss = torch.zeros((1,)).to(DEVICE) if calculate_loss else None
        seq_len = bio_output.shape[1]
        for pos in begin_positions:
            # Scan right from the token after the B tag until the I tags stop.
            end_mention = pos[1] + 1
            while end_mention < seq_len and \
                    bio_output[pos[0], end_mention] == self.inner:
                end_mention += 1
            end_mention -= 1
            first = X[pos[0], pos[1]]
            second = X[pos[0], end_mention]
            mention_span = torch.cat([first, second], 0).to(DEVICE)
            pseudo_entity_embedding = self.W_f(mention_span)  # d_ent
            # During training consider the whole entity dictionary
            if self.training and calculate_loss:
                alpha = F.softmax(self.E.weight.T.matmul(
                    pseudo_entity_embedding), dim=0)
                picked_entity = self.E.weight.matmul(alpha) + self.E.bias
            else:
                # K nearest neighbours
                topk = torch.topk(self.E.weight.T.matmul(
                    pseudo_entity_embedding), k)
                alpha = F.softmax(topk.values, dim=0)
                picked_entity = self.E.weight[:, topk.indices].matmul(alpha) \
                    + self.E.bias
            y[pos[0], pos[1]] = self.W_b(picked_entity)
            if calculate_loss:
                # Cross-entropy: negative log-probability of the gold entity.
                loss -= torch.log(alpha[entities_output[pos[0], pos[1]]])
        return loss, y
Here are the issues:
- We are using a for loop to find the end of each mention span for each input row in a batch - this is hardly vectorizable. Is there a way to vectorize that?
- The loss function is computed iteratively like this because we need a different alpha for each mention and its corresponding pseudo embedding. I assumed we could use a simple NLLLoss, but the problem is the alpha calculation. Clearly, one could collect all the different alphas in a list and
torch.stack()
them, but how does this interact with autograd? Does iteratively adding a value to the loss work with autograd?
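For context, here is the toy check I have in mind (not the real model): both iterative accumulation into a scalar and a `torch.stack` followed by a single reduction seem to build the same graph and produce the same gradients.

```python
import torch

w = torch.ones(3, requires_grad=True)

# Variant 1: accumulate a scalar loss inside a Python loop
loss = torch.zeros(())
for i in range(3):
    loss = loss + w[i] * (i + 1)
loss.backward()
grad_loop = w.grad.clone()

# Variant 2: stack per-step values, then reduce once
w.grad = None
parts = torch.stack([w[i] * (i + 1) for i in range(3)])
parts.sum().backward()
grad_stack = w.grad.clone()

# Both variants yield the same gradients: [1., 2., 3.]
assert torch.equal(grad_loop, grad_stack)
```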
I’d like to make it clear that this forward method works and trains, but at a very slow rate of roughly one batch per second (batch size 8 on a 2080 Ti), which makes it difficult to finetune on an appreciable fraction of Wikipedia.
To clarify, the reason why I am performing those matmuls is:
- to get every row of the E matrix dot-product’d with the pseudo entity embedding;
- to compute the right average of each “real” entity embedding (aka a row of the E matrix) with the scores computed by alpha. This is further stressed in the case where I compute the k nearest neighbours at evaluation time.
I realise that using matmuls for linear layers is an antipattern and suggestions are welcome.
PS: the idea of reading the entity memory as an attention mechanism is better described in Facts as Experts by Verga et al.
PS: Is MathJax enabled?