Confusion about the attention mechanism

Hello everybody,

I’ve read tons of posts and questions around the web on how best to use self-attention with a BiLSTM. Now I’m feeling a little confused and need some clarification before I continue studying.

I created a model with a BiLSTM like this:

self.lstm = nn.LSTM(input_size=input_size,      # 140 (embedding dim)
                    hidden_size=hidden_size,    # 400
                    num_layers=num_layers,      # 2
                    batch_first=batch_first,    # True
                    bidirectional=bidirectional,  # True
                    dropout=dropout)

and I have an attention mechanism implemented like this:

class SelfAttention(nn.Module):
    def __init__(self, attention_size, batch_first=False, non_linearity="tanh"):
        super(SelfAttention, self).__init__()

        self.batch_first = batch_first
        self.attention_weights = nn.Parameter(torch.FloatTensor(attention_size), requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

        if non_linearity == "relu":
            self.non_linearity = nn.ReLU()
        else:
            self.non_linearity = nn.Tanh()

        nn.init.uniform_(self.attention_weights.data, -0.005, 0.005)

    def get_mask(self, attentions, lengths):
        """
        Construct a mask for the padded timesteps, based on the sequence lengths
        """
        max_len = max(lengths.data)
        # ones_like keeps the mask on the same device and dtype as the scores,
        # so the deprecated Variable wrapper and the explicit .cuda() call are not needed
        mask = torch.ones_like(attentions)

        # zero out every timestep beyond a sequence's true length
        for i, l in enumerate(lengths.data):
            if l < max_len:
                mask[i, l:] = 0
        return mask

    def forward(self, inputs, lengths):

        ##################################################################
        # STEP 1 - perform dot product
        # of the attention vector and each hidden state
        ##################################################################

        # inputs is a 3D Tensor: batch, len, hidden_size
        # scores is a 2D Tensor: batch, len
        scores = self.non_linearity(inputs.matmul(self.attention_weights))
        scores = self.softmax(scores)

        ##################################################################
        # Step 2 - Masking
        ##################################################################

        # construct a mask, based on the sentence lengths
        mask = self.get_mask(scores, lengths)

        # apply the mask - zero out masked timesteps
        masked_scores = scores * mask

        # re-normalize the masked scores
        _sums = masked_scores.sum(-1, keepdim=True)  # sums per row
        scores = masked_scores.div(_sums)  # divide by row sum

        ##################################################################
        # Step 3 - Weight the hidden states by the attention scores
        ##################################################################

        # scale each hidden state by its attention score
        representations = torch.mul(inputs, scores.unsqueeze(-1).expand_as(inputs))

        return representations, scores

In the forward function of my model, is it correct to use the output of the BiLSTM as the input to the attention layer, or should I use h_n (or something else)?

    def forward(self, x, x_len):
        # pack the padded batch so the LSTM ignores the padded timesteps
        x = nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=True)
        out1, (h_n, c_n) = self.lstm(x)
        # out1 (after unpacking, batch_first=True): (batch, seq_len, num_directions * hidden_size)
        # h_n: (num_layers * num_directions, batch, hidden_size)
        x, lengths = nn.utils.rnn.pad_packed_sequence(out1, batch_first=True)
        x, att1 = self.atten1(x, lengths)
        return x

Thank you so much! 🙂


It depends on the context, and there are many ways to compute attention.

  1. One way, for language translation where you have an encoder and a decoder:
    Here you can compute the softmax of the dot product between the decoder hidden state and each encoder hidden state. Then multiply the softmax result with the encoder hidden states to get a context vector, which is fed into the decoder for prediction (see the sketch below).
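
For reference, here is a minimal sketch of that dot-product (Luong-style) attention step. The tensor names (decoder_hidden, encoder_outputs) and shapes are my own assumptions for illustration, not something from your code:

import torch
import torch.nn.functional as F

def dot_product_attention(decoder_hidden, encoder_outputs):
    # decoder_hidden:  (batch, hidden)
    # encoder_outputs: (batch, src_len, hidden)
    # dot product of the decoder state with every encoder state -> (batch, src_len)
    scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1)).squeeze(-1)
    weights = F.softmax(scores, dim=-1)  # (batch, src_len)
    # context vector: weighted sum of the encoder states, fed to the decoder
    context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch, hidden)
    return context, weights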

In your case, you only have an encoder, so you can use out1 and h_n as the inputs for attention.
I participated in a Kaggle competition where I faced a similar problem; I used both the output and the hidden state of the LSTM for attention and it worked. You can check the implementation here.
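
As a rough sketch of that idea under your setup (batch_first=True, bidirectional=True): use the last layer's forward and backward h_n as the query and attend over the unpacked BiLSTM outputs. The function and variable names below are just illustrative, not taken from any library:

import torch
import torch.nn.functional as F

def attend_with_hidden(outputs, h_n, lengths):
    # outputs: (batch, seq_len, 2 * hidden_size)  -- from pad_packed_sequence(..., batch_first=True)
    # h_n:     (num_layers * 2, batch, hidden_size)

    # last layer's forward and backward final states -> (batch, 2 * hidden_size)
    query = torch.cat([h_n[-2], h_n[-1]], dim=-1)

    # dot product of the query with every timestep -> (batch, seq_len)
    scores = torch.bmm(outputs, query.unsqueeze(-1)).squeeze(-1)

    # mask the padded timesteps before the softmax
    seq_len = outputs.size(1)
    lengths = lengths.to(outputs.device)
    mask = torch.arange(seq_len, device=outputs.device)[None, :] < lengths[:, None]
    scores = scores.masked_fill(~mask, float('-inf'))

    weights = F.softmax(scores, dim=-1)  # (batch, seq_len)
    # sentence representation: weighted sum of the BiLSTM outputs
    representation = torch.bmm(weights.unsqueeze(1), outputs).squeeze(1)
    return representation, weights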