Confused about my self-attention implementation

Hi everyone,

I tried implementing a self-attention mechanism for a sequence labeling task (like NER). I feed the BiLSTM outputs, with shape (batch_size, seq_length, embed_dim), and a mask with shape (batch_size, seq_length) into the self-attention layer. I want my model to be able to focus on the input sentence itself, but the performance is worse than that of the model without attention.
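For reference, the mask convention assumed in the code (shape (batch_size, seq_length), 1 for real tokens, 0 for padding) can be built from sentence lengths as below. This is a minimal sketch; `lengths_to_mask` is a hypothetical helper, not a PyTorch function.

```python
import torch


def lengths_to_mask(lengths, max_len=None):
    """Hypothetical helper: (batch_size,) lengths -> (batch_size, seq_length) 0/1 mask."""
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)  # (seq_length,)
    # A position is a real token if its index is smaller than the sentence length.
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()


mask = lengths_to_mask(torch.tensor([3, 1]))
print(mask)
# tensor([[1, 1, 1],
#         [1, 0, 0]])
```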

Here is my code:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head self-attention mechanism with a padding mask."""

    def __init__(self, emb_dim, kqv_dim, num_heads=1):
        super().__init__()
        self.emb_dim = emb_dim
        self.kqv_dim = kqv_dim
        self.num_heads = num_heads

        self.w_k = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_q = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_v = nn.Linear(emb_dim, kqv_dim * num_heads, bias=False)
        self.w_out = nn.Linear(kqv_dim * num_heads, emb_dim)

    def forward(self, inputs, attention_mask=None):
        # inputs: (b, t, emb_dim); attention_mask: (b, t) with 1 = token, 0 = padding
        b, t, _ = inputs.shape
        e = self.kqv_dim
        h = self.num_heads

        if attention_mask is not None:
            token_mask = attention_mask.unsqueeze(-1).to(inputs.dtype)  # (b, t, 1)
            inputs = inputs * token_mask

        # Project, then split heads: (b, t, h*e) -> (b, h, t, e) -> (b*h, t, e)
        keys = self.w_k(inputs).view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)
        queries = self.w_q(inputs).view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)
        values = self.w_v(inputs).view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        # Scaled dot-product scores: (b*h, t_query, t_key)
        attn_scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(e)

        if attention_mask is not None:
            # Mask padded KEYS (last dim) with -inf so softmax gives them zero weight.
            # Multiplying scores by 0 does not do this: exp(0) = 1 still gets weight.
            key_mask = attention_mask.bool().unsqueeze(1)    # (b, 1, t)
            key_mask = key_mask.repeat_interleave(h, dim=0)  # (b*h, 1, t), same order as heads
            attn_scores = attn_scores.masked_fill(~key_mask, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        weighted_values = torch.bmm(attn_weights, values).view(b, h, t, e)
        weighted_values = weighted_values.transpose(1, 2).reshape(b, t, h * e)
        if attention_mask is not None:
            weighted_values = weighted_values * token_mask  # zero out padded positions
        return self.w_out(weighted_values)
```
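Multiplying attention scores by a 0/1 mask before the softmax does not hide padding: a masked score becomes 0, and exp(0) = 1 still receives weight. (A mask of shape (batch_size, seq_length, 1) broadcast against scores of shape (batch_size*num_heads, seq_length, seq_length) also selects query rows rather than key columns.) A small check, assuming one query over three keys where the last key is padding:

```python
import torch
import torch.nn.functional as F

# One query attending over 3 keys; the last key is padding.
scores = torch.tensor([[2.0, 1.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = real token, 0 = padding

# Multiplying by the mask turns the padded score into 0, which is not "ignore".
weights_mul = F.softmax(scores * mask, dim=-1)

# Filling padded scores with -inf gives them exactly zero weight.
weights_fill = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)

print(weights_mul)   # padded key still gets about 0.09 of the weight
print(weights_fill)  # padded key gets 0
```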

What’s wrong with my code?
Thank you!

Have you thought of using the Transformers instead?
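If "Transformers" here means PyTorch's built-in attention layers (an assumption; the reply may refer to a specific library), `torch.nn.MultiheadAttention` already handles the multi-head projections and padding masks. A minimal sketch with random stand-in BiLSTM outputs; note that `key_padding_mask` uses True for positions to ignore, the opposite of a 1/0 token mask, and `batch_first=True` requires PyTorch 1.9 or later.

```python
import torch
import torch.nn as nn

batch_size, seq_length, embed_dim, num_heads = 2, 4, 8, 2
lstm_out = torch.randn(batch_size, seq_length, embed_dim)  # stand-in for BiLSTM outputs
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])  # 1 = real token, 0 = padding

attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
# key_padding_mask expects True at positions that should be IGNORED.
out, weights = attn(lstm_out, lstm_out, lstm_out, key_padding_mask=(mask == 0))

print(out.shape)      # torch.Size([2, 4, 8])
print(weights.shape)  # torch.Size([2, 4, 4]), averaged over heads by default
```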

I have already thought of using Transformers, but I want to try this approach first.
My implementation is based on the self-attention in Transformers. Maybe it does not fit my goal.
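For comparison, the self-attention sublayer in the original Transformer encoder is not used on its own: its output is added back to its input (a residual connection) and then layer-normalized. A minimal sketch of that wrapping around BiLSTM outputs, where `AttentionBlock` is a hypothetical name and the dropout value is an arbitrary assumption; whether this improves NER results here is untested. `nn.TransformerEncoderLayer` packages the same pattern together with a feed-forward sublayer.

```python
import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    """Hypothetical post-LN sublayer: LayerNorm(x + Dropout(SelfAttention(x)))."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, mask):
        # mask: (batch, seq_length) with 1 = token, 0 = padding
        attn_out, _ = self.attn(x, x, x, key_padding_mask=(mask == 0), need_weights=False)
        # Residual connection keeps the BiLSTM features; attention only adds to them.
        return self.norm(x + self.dropout(attn_out))


block = AttentionBlock(embed_dim=8, num_heads=2).eval()
x = torch.randn(2, 4, 8)  # stand-in for BiLSTM outputs
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
y = block(x, mask)
print(y.shape)  # torch.Size([2, 4, 8])
```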