Hi guys,
I’m trying to implement the attention mechanism described in this paper.
I have implemented the encoder and the decoder modules (the latter will be called one step at a time when decoding a minibatch of sequences).
As a sanity check, I’m trying to overfit a very small dataset, but I’m getting worse results than with a recurrent decoder that doesn’t use the attention mechanism I implemented. So it’s clear that I’ve made a mistake in my implementation, but I haven’t been able to find it yet. I’ve already had a look at some of the resources available on this topic ([1], [2] or [3]).
Could you please review the code snippets below and point out any possible errors?
Thank you very much in advance!
This is the implemented attention module:
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionModule(nn.Module):
    """ This layer calculates the context vectors for the current decoder time step based on the previous
    hidden state of the decoder and the (masked) encoder annotations.

    Arguments:
    ----------
    hidden_dim_encoder: int
        Number of units in a hidden cell of the encoder.
    hidden_dim_decoder: int
        Number of units in a hidden cell of the decoder.
    batch_first: boolean
        If True, the first dimension of the input tensors corresponds to the number of examples in the batch.
        Otherwise, it corresponds to the number of time steps.
    """
    def __init__(self, hidden_dim_encoder, hidden_dim_decoder, batch_first=False):
        super(AttentionModule, self).__init__()
        assert isinstance(hidden_dim_encoder, int) and hidden_dim_encoder > 0, "Invalid value for the encoder hidden size"
        assert isinstance(hidden_dim_decoder, int) and hidden_dim_decoder > 0, "Invalid value for the decoder hidden size"
        self._batch_first = batch_first
        self.align = nn.Linear(hidden_dim_encoder + hidden_dim_decoder, hidden_dim_decoder)
        self.v_a = nn.Parameter(torch.randn(hidden_dim_decoder))

    def forward(self, s_tm1, encoder_annotations_masked, src_mask):
        """ Forward pass of the layer.

        Notation:
        ---------
        N: Number of examples.
        Tx: Length of the input sequence.
        H_enc: Number of hidden units in the encoder.
        H_dec: Number of hidden units in the decoder.

        Arguments:
        ----------
        s_tm1: torch.FloatTensor
            Previous hidden state of the decoder. Shape (N, 1, H_dec) if batch_first is True, otherwise (1, N, H_dec).
        encoder_annotations_masked: torch.FloatTensor
            Masked annotations produced by the encoder. Shape (N, Tx, H_enc) if batch_first is True, otherwise
            (Tx, N, H_enc).
        src_mask: torch.LongTensor
            Minibatch of source sequence masks (binary matrix) of size (N, Tx) if batch_first is True.
            Otherwise (Tx, N).

        Returns:
        --------
        context_vectors: torch.FloatTensor
            Context vectors of shape (N, 1, H_enc) if batch_first is True, otherwise (1, N, H_enc).
        attn_weights_masked: torch.FloatTensor
            Attention weights of shape (N, Tx, 1) if batch_first is True, otherwise (Tx, N, 1).
        """
        # extract the necessary dimensions and expand s_tm1 to a size that matches the encoder time steps
        if self._batch_first:
            N, Tx = src_mask.size()
            s_tm1_expanded = s_tm1.expand(-1, encoder_annotations_masked.size(0), -1)
        else:
            Tx, N = src_mask.size()
            s_tm1_expanded = s_tm1.expand(encoder_annotations_masked.size(0), -1, -1)
        # concatenate with the encoder annotations
        s_tm1_annotations_concat = torch.cat(
            (s_tm1_expanded, encoder_annotations_masked), dim=2
        )
        # project the concatenation through the alignment layer and squash with tanh
        out = torch.tanh(self.align(s_tm1_annotations_concat))
        # dot product with v_a to obtain the alignment scores
        alignment_scores = torch.matmul(out, self.v_a.t())
        # calculate the weights, mask them, apply them to the annotations and compute the context vectors
        if self._batch_first:
            attn_weights = F.softmax(alignment_scores, dim=1)
            attn_weights_masked = attn_weights * src_mask
            attn_weights_masked = attn_weights_masked.unsqueeze(2)
            context_vectors = torch.sum(attn_weights_masked * encoder_annotations_masked, dim=1)
            context_vectors = context_vectors.unsqueeze(1)
        else:
            attn_weights = F.softmax(alignment_scores, dim=0)
            attn_weights_masked = attn_weights * src_mask
            attn_weights_masked = attn_weights_masked.unsqueeze(2)
            context_vectors = torch.sum(attn_weights_masked * encoder_annotations_masked, dim=0)
            context_vectors = context_vectors.unsqueeze(0)
        return context_vectors, attn_weights_masked
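In case it helps to reproduce, here is a minimal shape check for the attention module in isolation (the dimensions are arbitrary toy values and batch_first is left at its default of False, so I expect context vectors of shape (1, N, H_enc) and attention weights of shape (Tx, N, 1)):

import torch

# toy dimensions, used only to check the output shapes of the attention module
N, Tx, H_enc, H_dec = 2, 5, 8, 6

attn = AttentionModule(hidden_dim_encoder=H_enc, hidden_dim_decoder=H_dec)

s_tm1 = torch.zeros(1, N, H_dec)          # previous decoder hidden state
annotations = torch.randn(Tx, N, H_enc)   # (masked) encoder annotations
src_mask = torch.ones(Tx, N)              # all source positions valid (float mask for simplicity)

context_vectors, attn_weights = attn(s_tm1, annotations, src_mask)
print(context_vectors.size())             # expecting torch.Size([1, 2, 8])
print(attn_weights.size())                # expecting torch.Size([5, 2, 1])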
This is the forward function of the recurrent decoder:
def forward(self, x_input, s_tm1, encoder_annotations_masked, source_mask, target_mask):
    """ Forward function of the decoder. This will be called one step at a time.

    Notation:
    ---------
    N: number of examples in the minibatch
    E: embedding dimensionality
    H_dec: number of hidden units in a decoder RNN cell
    H_enc: number of hidden units in an encoder RNN cell
    Tx: length of the source sequence (the time dimension of the encoder annotations, the source mask
        and the returned attention weights).
    Ty: length of the target sequence. Since this module is called one step at a time, x_input and
        target_mask passed in at each call cover a single time step.

    Arguments:
    ----------
    x_input: torch.LongTensor
        Minibatch of input tokens for the current time step, of size (N, 1) if batch_first is True,
        otherwise (1, N).
    s_tm1: tuple of torch.FloatTensors
        Tuple of tensors, where the first item corresponds to the hidden state and, if the rnn_type is "lstm",
        the second to the cell state.
        Each tensor is of the standard size (NumLayers*NumDirections, N, H_dec).
    encoder_annotations_masked: torch.FloatTensor
        Annotations produced by the encoder of size (N, Tx, H_enc) if batch_first is True, otherwise (Tx, N, H_enc).
    source_mask: torch.LongTensor
        Minibatch of sequence masks (binary matrix) for the source sequence of size (N, Tx) if batch_first is True.
        Otherwise (Tx, N).
    target_mask: torch.LongTensor
        Minibatch of sequence masks (binary matrix) for the current target time step of size (N, 1) if batch_first
        is True. Otherwise (1, N).

    Returns:
    --------
    out: torch.FloatTensor
        LogSoftmaxed scores produced at time step t.
    s_t: tuple
        Tuple that contains the latest hidden state and - if the decoder rnn_type is "lstm" - the cell state
        represented as torch.FloatTensor. These tensors are of the same shape as described in the documentation.
    attn_weights: torch.FloatTensor
        Attention weights of shape (N, Tx, 1) if batch_first is True, otherwise (Tx, N, 1).
    """
    if self._batch_first:
        N, Tx = x_input.size()
        Ty = target_mask.size()[1]
    else:
        Tx, N = x_input.size()
        Ty = target_mask.size()[0]
    # embedding
    x_embedded = self._embedding(x_input)
    x_embedded = self._dropout_embedding(x_embedded)
    # context vectors
    context_vectors, attn_weights = self.attend(s_tm1[0], encoder_annotations_masked, source_mask)
    # concat embeddings and context vectors
    x_rnn_inp = torch.cat((x_embedded, context_vectors), dim=2)
    # rnn
    s_all, s_t = self._rnn(x_rnn_inp, s_tm1)
    # softmaxed scores
    out = F.softmax(self._fc_scores(s_t[0]), dim=1)
    return out, s_t, attn_weights
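To show how everything is wired together, here is a simplified sketch of the step-by-step loop I described above (batch_first=False and ground-truth tokens fed at each step, i.e. teacher forcing; names such as decoder, initial_decoder_state, targets and the masks are placeholders for my actual objects, not the exact code I run):

# placeholders: `decoder` is the recurrent decoder above, `initial_decoder_state` is the tuple of
# initial hidden (and cell) state, `targets` holds target token ids of shape (Ty, N)
s_t = initial_decoder_state
outputs, attention_maps = [], []
Ty = targets.size(0)
for t in range(Ty - 1):
    x_t = targets[t].unsqueeze(0)            # (1, N): tokens fed to the decoder at step t
    mask_t = target_mask[t].unsqueeze(0)     # (1, N): target mask for this step
    out, s_t, attn_weights = decoder(x_t, s_t, encoder_annotations_masked, source_mask, mask_t)
    outputs.append(out)
    attention_maps.append(attn_weights)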