Implementing Bahdanau's Attention

Hi guys,

I’m trying to implement the attention mechanism described in this paper.
I have implemented the encoder and the decoder modules (the latter will be called one step at a time when decoding a minibatch of sequences).
As a sanity check, I’m trying to overfit a very small dataset, but I’m getting worse results than with a plain recurrent decoder that doesn’t use the attention mechanism I implemented. So it’s clear that I’ve made a mistake somewhere in my implementation, but I haven’t been able to find it yet. I’ve already had a look at some of the resources available on this topic ([1], [2] or [3]).
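
For reference, this is the additive alignment model from the paper as I understand it (using t for the decoder time step, s_{t-1} for the previous decoder state and h_j for the j-th encoder annotation; my align layer applies W_a and U_a jointly to the concatenation of the two):

e_{tj}     = v_a^T tanh(W_a s_{t-1} + U_a h_j)
alpha_{tj} = exp(e_{tj}) / sum_k exp(e_{tk})
c_t        = sum_j alpha_{tj} h_j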

Could you please review the code snippets below and point out possible errors?

Thank you very much in advance!

This is the attention module I implemented:

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionModule(nn.Module):
    """ This layer calculates the context vectors for the current decoder time step based on the current input, the 
    previous hidden state and the decoder annotations.
    Arguments:
    ----------
    hidden_dim_encoder: int
        Number of units in a hidden cell of the encoder.
    hidden_dim_decoder: int
        Number of units in a hidden cell of the decoder.
    batch_first: boolean
        If True, then the first dimension of the input tensors must be equal to the number of examples in the batch.
        Otherwise, the time steps.
    """

    def __init__(self, hidden_dim_encoder, hidden_dim_decoder, batch_first=False):
        super(AttentionModule, self).__init__()
        assert isinstance(hidden_dim_encoder, int) and hidden_dim_encoder > 0, "Invalid value for the encoder hidden size"
        assert isinstance(hidden_dim_decoder, int) and hidden_dim_decoder > 0, "Invalid value for the decoder hidden size"

        self._batch_first = batch_first

        self.align = nn.Linear(hidden_dim_encoder + hidden_dim_decoder, hidden_dim_decoder)
        self.v_a = nn.Parameter(torch.randn(hidden_dim_decoder))

    def forward(self, s_tm1, encoder_annotations_masked, src_mask):
        """ Forward pass of the layer.
        Notation:
        ---------
        N: Number of examples.
        Tx: Length of the input sequence.
        H_enc: Number of hidden units in the encoder.
        H_dec: Number of hidden units in the decoder.

        Arguments:
        ----------
        s_tm1: torch.FloatTensor
            Previous hidden state of the decoder. Shape (N, 1, H_dec) if batch_first is True, otherwise (1, N, H_dec)
        encoder_annotations_masked: torch.FloatTensor
            Masked annotations produced by the encoder. Shape (N, Tx, H_enc) if batch_first is True, otherwise 
            (Tx, N, H_enc).
        src_mask: torch.LongTensor
            Minibatch of source sequence masks (binary matrix) of size (N, Tx) given that batch_first is True.
            Otherwise (Tx, N).
        Returns:
        -------
        context_vectors: torch.FloatTensor
            Context vectors of shape (N, 1, H_enc) if batch_first is True, otherwise (1, N, H_enc).
        attn_weights_masked: torch.FloatTensor
            Attention weights of shape (N, Tx, 1) if batch_first is True, otherwise (Tx, N, 1).
        """
        # extract necessary dimensions and expand s_tm1 to a size that matches the encoder time steps
        if self._batch_first:
            N, Tx = src_mask.size()
            s_tm1_expanded = s_tm1.expand(-1, encoder_annotations_masked.size(1), -1)
        else:
            Tx, N = src_mask.size()
            s_tm1_expanded = s_tm1.expand(encoder_annotations_masked.size(0), -1, -1)
        # concatenate with the encoder annotations
        s_tm1_annotations_concat = torch.cat(
            (s_tm1_expanded, encoder_annotations_masked), dim=2
        )
        # calculate alignment scores and propagate through a linear layer
        out = torch.tanh(self.align(s_tm1_annotations_concat))
        # dot product with v_a
        alignment_scores = torch.matmul(out, self.v_a.t())
        # calculate weights, mask them, apply on annotations and calculate context vectors
        if self._batch_first:
            attn_weights = F.softmax(alignment_scores, dim=1)
            attn_weights_masked = attn_weights * src_mask
            attn_weights_masked = attn_weights_masked.unsqueeze(2)
            context_vectors = torch.sum(attn_weights_masked * encoder_annotations_masked, dim=1)
            context_vectors = context_vectors.unsqueeze(1)
        else:
            attn_weights = F.softmax(alignment_scores, dim=0)
            attn_weights_masked = attn_weights * src_mask
            attn_weights_masked = attn_weights_masked.unsqueeze(2)
            context_vectors = torch.sum(attn_weights_masked * encoder_annotations_masked, dim=0)
            context_vectors = context_vectors.unsqueeze(0)
        
        return context_vectors, attn_weights_masked
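
In case it helps with the review, this is a minimal way I exercise the module on dummy tensors (all sizes below are made up, and I’m using the batch_first=False layout):

import torch

N, Tx, H_enc, H_dec = 2, 5, 8, 8
attn = AttentionModule(H_enc, H_dec, batch_first=False)

s_tm1 = torch.randn(1, N, H_dec)            # previous decoder hidden state, (1, N, H_dec)
annotations = torch.randn(Tx, N, H_enc)     # (masked) encoder annotations, (Tx, N, H_enc)
src_mask = torch.ones(Tx, N, dtype=torch.long)
src_mask[3:, 1] = 0                         # pretend the second sequence is padded after step 3

context, weights = attn(s_tm1, annotations, src_mask)
print(context.shape)                        # expected: torch.Size([1, 2, 8])
# For each example, the weights over its non-padded positions should sum to 1;
# if they don't, the interaction between the softmax and the mask is worth a second look.
print(weights.squeeze(2).sum(dim=0))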

This is the forward function of the recurrent decoder:

def forward(self, x_input, s_tm1, encoder_annotations_masked, source_mask, target_mask):
    """ Forward function of the decoder. This will be called one step at a time.
    Notation:
    ---------
    N: number of examples in the minibatch
    E: embedding dimensionality
    H_dec: number of hidden units in a decoder RNN cell
    H_enc: number of hidden units in an encoder RNN cell
    Tx: length of the source sequence
    Ty: length of the target sequence. Since this module will be called one step at a time, 
    this will be equal to 1.
    Arguments:
    ----------
    x_input: torch.LongTensor
        Minibatch of current input tokens of size (N, Ty) if batch_first is True, otherwise (Ty, N).
    s_tm1: tuple of torch.FloatTensors
        Tuple of tensors, where the first item corresponds to the hidden state and, if the rnn_type is "lstm", the 
        second to the cell state.
        Each tensor is of the standard size (NumLayers*NumDirections, N, H_dec)
    encoder_annotations_masked: torch.FloatTensor
        Annotations produced by the encoder of size (N, Tx, H_enc) if batch_first is True, otherwise (Tx, N, H_enc)
    source_mask: torch.LongTensor
        Minibatch of sequence masks (binary matrix) for the source sequence of size (N, Tx) if batch_first is True.
        Otherwise (Tx, N).
    target_mask: torch.LongTensor
        Minibatch of sequence masks (binary matrix) for the target sequence of size (N, Ty) if batch_first is True.
        Otherwise (Ty, N).
    Returns:
    --------
    out: torch.FloatTensor
        LogSoftmaxed scores produced at time step t.
    s_t: tuple
        Tuple that contains the latest hidden state and - if the decoder rnn_type is "lstm" - the cell state 
        represented as torch.FloatTensor. These tensors are of the same shape as described in the documentation.
    attn_weights: torch.FloatTensor
        Attention weights of shape (N, Tx, 1) if batch_first is True, otherwise (Tx, N, 1).
    """
    if self._batch_first:
        N, Tx = x_input.size()
        Ty = target_mask.size()[1]
    else:
        Tx, N = x_input.size()
        Ty = target_mask.size()[0]
    # embedding
    x_embedded = self._embedding(x_input)
    x_embedded = self._dropout_embedding(x_embedded)
    # context vectors
    context_vectors, attn_weights = self.attend(s_tm1[0], encoder_annotations_masked, source_mask)
    # concat embeddings and context vectors
    x_rnn_inp = torch.cat((x_embedded, context_vectors), dim = 2)
    # rnn
    s_all, s_t = self._rnn(x_rnn_inp, s_tm1)
    # log-softmaxed scores over the vocabulary
    out = F.log_softmax(self._fc_scores(s_t[0]), dim=2)
    return out, s_t, attn_weights
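
For completeness, this is roughly how I drive the decoder during training; it’s a simplified sketch of my teacher-forcing loop (batch_first=False, and criterion is meant to be nn.NLLLoss since the decoder returns log-softmaxed scores), not the exact code:

def decode_minibatch(decoder, s_0, targets, target_mask,
                     encoder_annotations_masked, source_mask, criterion):
    """ Feed the decoder one step at a time with teacher forcing.
    targets and target_mask are the full target-side tensors of size (T, N);
    returns the accumulated loss for the minibatch. """
    s_t = s_0
    loss = 0.0
    for t in range(targets.size(0) - 1):
        x_t = targets[t].unsqueeze(0)                    # current input tokens, shape (1, N)
        out, s_t, attn_weights = decoder(
            x_t, s_t, encoder_annotations_masked, source_mask, target_mask[t].unsqueeze(0)
        )
        # out holds log-probabilities of shape (1, N, vocab); the target is the next token
        loss = loss + criterion(out.squeeze(0), targets[t + 1])
    return loss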

I’m rather sure that the PyTorch Seq2Seq Tutorial implements Bahdanau attention.

Thank you! I can’t believe I missed that…