MultiheadAttention after LSTM returns the same output for all input

Hi,

I would like to use MultiheadAttention as self-attention after applying LSTM on a single sequence.

import torch

# Toy problem sizes: time steps, sequences per batch, features per step.
n_steps, n_batch, n_feat = 5, 3, 10

# Random input laid out as (sequence length, batch size, embedding dimension).
inp = torch.randn(n_steps, n_batch, n_feat)
lstm = torch.nn.LSTM(n_feat, n_feat, num_layers=2)
self_attn = torch.nn.MultiheadAttention(n_feat, num_heads=2)

# Keep only the per-step LSTM outputs; the final hidden/cell state is dropped.
x = lstm(inp)[0]
# Self-attention: the LSTM output is used as query, key and value alike.
out, weight = self_attn(x, x, x)

out
tensor([[[ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273]],

        [[ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273]],

        [[ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273]],

        [[ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273]],

        [[ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273]]], grad_fn=<AddBackward0>)

# shape: (batch size, sequence length, embedding dimension)
out.transpose(0, 1)
tensor([[[ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192],
         [ 0.0137,  0.0457, -0.0169, -0.0393, -0.0214,  0.0162,  0.0534,
           0.0202,  0.0519, -0.0192]],

        [[ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242],
         [ 0.0116,  0.0359, -0.0044, -0.0320, -0.0294,  0.0175,  0.0573,
           0.0373,  0.0316, -0.0242]],

        [[ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273],
         [ 0.0128,  0.0410, -0.0083, -0.0347, -0.0234,  0.0136,  0.0642,
           0.0148,  0.0425, -0.0273]]], grad_fn=<TransposeBackward0>)

As we can see, every token in each input sequence gets the same attention output. It looks as if no token was actually attended to.

I have tried using attn_mask and key_padding_mask but it’s no use.

Can someone explain what happened? And please correct me if I am applying them the wrong way.

Thanks!

I tried to reproduce the steps of the source code by hand, but I don’t get the same results as the module (I must have forgotten something):

Let’s start by doing it directly:

import torch
import torch.nn.functional as F

# Seed the RNG so the random input and the module initialisations below are
# reproducible; the statement order therefore matters for the recorded outputs.
torch.manual_seed(0)

# Problem sizes. For self-attention the source (key/value) and target (query)
# lengths are both the sequence length.
seq_len, bsz, embedding_dim = 3, 1, 6
src_len = seq_len
tgt_len = seq_len

# A single head with 4 features, so hidden_size == 4 here.
num_heads = 1
hidden_size = num_heads*4

# Random input of shape (seq_len, bsz, embedding_dim) fed to a 2-layer LSTM
# that maps embedding_dim -> hidden_size.
inp = torch.randn(seq_len, bsz, embedding_dim)
lstm = torch.nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=2)

# The attention layer consumes the LSTM output, so its embedding size is the
# LSTM hidden size.
embed_dim = hidden_size
self_attn = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

# Only the per-step outputs are kept; the final LSTM state is discarded.
x, _ = lstm(input=inp) # seq_len x bsz x hidden_size
"""
tensor([[[-0.0383,  0.2014,  0.0948,  0.0756]],

        [[ 0.0603,  0.2379,  0.1403,  0.2036]],

        [[ 0.1172,  0.2602,  0.1392,  0.2943]]], grad_fn=<StackBackward>)
"""

# Reference result from the built-in module: x is query, key and value.
out, weight = self_attn(query=x, key=x, value=x) # seq_len x bsz x hidden_size, ...
"""
tensor([[[ 0.0262, -0.0770,  0.1162,  0.1179]],

        [[ 0.0262, -0.0770,  0.1162,  0.1179]],

        [[ 0.0262, -0.0770,  0.1162,  0.1179]]], grad_fn=<AddBackward0>)

and

tensor([[[0.3330, 0.3334, 0.3336],
         [0.3327, 0.3334, 0.3339],
         [0.3325, 0.3335, 0.3340]]], grad_fn=<DivBackward0>)
"""

Now, if we want to do it ourselves, we can start by noticing that:

self_attn.q_proj_weight, self_attn.k_proj_weight, self_attn.v_proj_weight
"""
(None, None, None)
"""

self_attn.in_proj_weight
"""
tensor([[ 0.5861,  0.4650, -0.2232,  0.3442],
        [-0.3479, -0.0960,  0.5200,  0.0253],
        [-0.4331, -0.2047, -0.1662, -0.1181],
        [ 0.0586,  0.5663,  0.0328, -0.3781],
        [ 0.0314,  0.2936,  0.3038, -0.5597],
        [-0.1096, -0.4551, -0.2613,  0.2206],
        [-0.4349,  0.2276,  0.5198,  0.0402],
        [-0.4081, -0.2194,  0.1337, -0.4668],
        [ 0.3042, -0.5560, -0.5887, -0.5950],
        [-0.1242,  0.4118, -0.5796,  0.5090],
        [-0.2450,  0.1794,  0.0279, -0.5522],
        [ 0.5079,  0.3297,  0.6087,  0.3094]], requires_grad=True)
"""

So, as here, we have (also, in the forward method of MultiheadAttention):

# The separate q/k/v projection weights are None for this module, so the packed
# in_proj_weight path is the one to reproduce.
use_separate_proj_weight = False
if not use_separate_proj_weight:
        # One linear layer produces 3 * embed_dim features per position; splitting
        # the last axis into three equal chunks yields q, k and v.
        #q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
        q, k, v = F.linear(input=x, weight=self_attn.in_proj_weight, bias=self_attn.in_proj_bias).chunk(3, dim=-1)
"""
(tensor([[[ 0.0761,  0.0452, -0.0493,  0.0864]],
 
         [[ 0.1847,  0.0343, -0.1222,  0.0659]],
 
         [[ 0.2600,  0.0141, -0.1620,  0.0475]]], grad_fn=<SplitBackward>),
 tensor([[[ 0.0444, -0.0956,  0.1148, -0.0512]],
 
         [[ 0.0004, -0.1066,  0.1091, -0.1531]],
 
         [[-0.0423, -0.1027,  0.0925, -0.2237]]], grad_fn=<SplitBackward>),
 tensor([[[-0.2244,  0.0712,  0.0064,  0.1280]],
 
         [[-0.3177,  0.1128, -0.0806,  0.2575]],
 
         [[-0.3661,  0.1617, -0.1407,  0.3212]]], grad_fn=<SplitBackward>))
"""

In our case, attn_mask is None (see here), key_padding_mask is None (see here), and bias_k and bias_v are both None (see here), so we have:

# Per-head feature size and the 1/sqrt(head_dim) factor of scaled dot-product
# attention.
head_dim = embed_dim // num_heads
scaling = float(head_dim) ** -0.5

# The scaling is applied to the queries before the dot product.
q = q * scaling

# Fold the heads into the batch axis:
# (len, bsz, embed_dim) -> (len, bsz * num_heads, head_dim) -> (bsz * num_heads, len, head_dim)
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
  k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
  v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)


# Raw (pre-softmax) scores: batched q @ k^T, one (tgt_len x key_len) matrix per
# batch-head entry.
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
"""
tensor([[[0.3373, 0.3288, 0.3339],
         [0.3328, 0.3401, 0.3271],
         [0.3250, 0.3258, 0.3491]]], grad_fn=<SoftmaxBackward>)

"""

# Normalise the scores over the key axis so each query's weights sum to 1.
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
# With p=0 and training=False this dropout call leaves the weights unchanged;
# it is kept only to mirror the steps of the library implementation.
attn_output_weights = F.dropout(attn_output_weights, p=0, training=False)
"""
tensor([[[0.3352, 0.3331, 0.3317],
         [0.3355, 0.3331, 0.3314],
         [0.3357, 0.3331, 0.3312]]], grad_fn=<SoftmaxBackward>)
"""

# Weighted sum of the value vectors: one head_dim-sized output per query position.
attn_output = torch.bmm(attn_output_weights, v)
"""
tensor([[[-0.3025,  0.1151, -0.0714,  0.2352],
         [-0.3024,  0.1150, -0.0713,  0.2351],
         [-0.3024,  0.1150, -0.0713,  0.2351]]], grad_fn=<BmmBackward0>)
"""

# Undo the head folding: back to length-first layout, then merge the heads into
# a single embed_dim feature axis -> (tgt_len, bsz, embed_dim).
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
# Final linear layer, using the built-in module's own output-projection parameters.
attn_output = F.linear(attn_output, self_attn.out_proj.weight, self_attn.out_proj.bias)
"""
tensor([[[ 0.0262, -0.0770,  0.1162,  0.1179]],

        [[ 0.0262, -0.0770,  0.1162,  0.1179]],

        [[ 0.0262, -0.0770,  0.1162,  0.1179]]], grad_fn=<AddBackward0>)

"""

need_weights = True
if need_weights:
    # average attention weights over heads
    # Unfold the batch-head axis into (bsz, num_heads, tgt_len, src_len), then
    # sum over the head axis and divide by num_heads to get the per-batch mean.
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 
    weights = attn_output_weights.sum(dim=1) / num_heads
"""
tensor([[[0.3352, 0.3331, 0.3317],
          [0.3355, 0.3331, 0.3314],
          [0.3357, 0.3331, 0.3312]]], grad_fn=<DivBackward0>))
"""

It seems to me that there is a problem.

Here is an implementation that could also help you (adapted from XLM/transformer.py at master · facebookresearch/XLM · GitHub); it’s what I always use in my projects:

import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Batch-first multi-head scaled dot-product attention.

    Supports self-attention (``kv is None``), attention over a separate
    source sequence (``kv`` given) and incremental decoding through a
    ``cache`` dict that stores the keys/values of previous steps.
    """

    # Counter handing out a distinct ``layer_id`` to every instance; the id is
    # the key under which an instance stores its (k, v) pair in ``cache``.
    _next_layer_id = 0

    def __init__(self, n_heads, dim, dropout, temperature = None):
        """
        n_heads: number of attention heads; must divide ``dim``.
        dim: model (embedding) dimension of inputs and outputs.
        dropout: dropout probability applied to the attention weights.
        temperature: divisor applied to the queries; defaults to
            sqrt(dim // n_heads).
        """
        super().__init__()
        # forward() indexes the cache with self.layer_id, so it must exist on
        # every instance and be unique per instance.
        self.layer_id = MultiHeadAttention._next_layer_id
        MultiHeadAttention._next_layer_id += 1

        self.dim = dim
        self.n_heads = n_heads
        self.dropout = dropout
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

        self.dim_per_head = dim // n_heads # d_k = d_v = dim // n_heads
        if temperature is None :
            self.temperature = float(self.dim_per_head) ** 0.5 # math.sqrt(self.dim_per_head)
        else :
            self.temperature = temperature

    def forward(self, input, mask, kv=None, cache=None, need_weights = True):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        input: tensor of shape (bs, qlen, dim).
        mask: tensor of shape (bs, klen) (non-causal) or (bs, qlen, klen);
            positions equal to 0 are excluded from attention. ``None``
            disables masking.
        kv: optional tensor of shape (bs, klen, dim) to attend over.
        cache: optional dict for incremental decoding. ``cache['slen']`` is
            the number of positions already cached; the (k, v) pair of this
            layer is read from / written to ``cache[self.layer_id]``.
        need_weights: when true, return ``(output, weights)`` with weights of
            shape (bs, n_heads, qlen, klen); otherwise return only the output.

        Returns the attended output of shape (bs, qlen, dim), optionally
        together with the attention weights.
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, qlen, klen)
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)
        assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        n_heads = self.n_heads
        dim_per_head = self.dim_per_head

        def shape(x):
            """  projection: (bs, len, dim) -> (bs, n_heads, len, dim_per_head) """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """  compute context: (bs, n_heads, len, dim_per_head) -> (bs, len, dim) """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))                                          # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))                                      # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k_lin(k))                                          # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))                                          # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # Self-attention: extend the cached keys/values with the new step(s).
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                else:
                    # Source attention: the projected source never changes, reuse it.
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / self.temperature                                              # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))                           # (bs, n_heads, qlen, klen)
        if mask is not None:
            mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
            mask = (mask == 0).view(mask_reshape).expand_as(scores)           # (bs, n_heads, qlen, klen)
            # The explicit cast below avoids this warning on older PyTorch versions:
            # /pytorch/aten/src/ATen/native/cuda/LegacyDefinitions.cpp:19: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
            mask = mask.bool()
            scores.masked_fill_(mask, -float('inf'))                          # (bs, n_heads, qlen, klen)

        weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
        context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)                                            # (bs, qlen, dim)

        if need_weights :
            return self.out_lin(context), weights
        else :
            return self.out_lin(context)

Test

self_attn = MultiHeadAttention(n_heads = num_heads, dim = hidden_size, dropout=0)
# All-ones mask of shape (bsz, seq_len): every position may be attended to.
mask = torch.ones((bsz, seq_len))

# The LSTM output x is (seq_len, bsz, hidden_size), but this module reads its
# input as (bs, qlen, dim). Passing x as-is would treat every time step as a
# separate length-1 sequence, so move the batch axis first.
# att is (bsz, seq_len, hidden_size); w is (bsz, n_heads, seq_len, seq_len).
att, w = self_attn(input = x.transpose(0, 1), mask = mask, kv=None, cache=None)

Notebook

Thank you for the detailed reply!

I’ll test it later.

Hey, I have the same problem as you. Have you solved it?