BiLSTM self-attention output dim

Hi everyone,

For several days I have been trying to implement a self-attention mechanism for a BiLSTM. The code I wrote for the attention, based on some resources I found on the web, is the following:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):

    def __init__(self, hidden_size, batch_first=False):
        super(Attention, self).__init__()

        self.hidden_size = hidden_size
        self.batch_first = batch_first

        # single learnable attention vector of shape (1, hidden_size)
        self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=True)

        stdv = 1.0 / np.sqrt(self.hidden_size)
        nn.init.uniform_(self.att_weights, -stdv, stdv)

    def forward(self, inputs, lengths):
        # inputs: (batch_size, max_len, hidden_size), padded and batch-first
        batch_size, max_len = inputs.size()[:2]

        # score every timestep against the learned attention vector
        weights = torch.bmm(inputs,
                            self.att_weights  # (1, hidden_size)
                            .permute(1, 0)  # (hidden_size, 1)
                            .unsqueeze(0)  # (1, hidden_size, 1)
                            .repeat(batch_size, 1, 1)  # (batch_size, hidden_size, 1)
                            )

        # (batch_size, max_len): one normalized score per timestep
        attentions = torch.softmax(F.relu(weights.squeeze(2)), dim=-1)

        # mask the padded positions based on the sentence lengths;
        # the mask itself does not need gradients
        mask = torch.ones(attentions.size(), device=inputs.device)
        for i, l in enumerate(lengths):
            if l < max_len:
                mask[i, l:] = 0

        # apply mask and renormalize attention scores (weights)
        masked = attentions * mask
        _sums = masked.sum(-1).unsqueeze(-1)  # sums per row

        attentions = masked.div(_sums)

        # apply attention weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(1)

        # (batch_size, n_lstm_unit) and (batch_size, sentence_len)
        return representations, attentions 

class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, batch_first, bidirectional, dropout):
        super(MyLSTM, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.lstm1 = nn.LSTM(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             batch_first=batch_first,
                             bidirectional=bidirectional,
                             dropout=dropout)
        self.atten1 = Attention(hidden_size * 2, batch_first=batch_first)  # x2 because the LSTM is bidirectional

    def forward(self, x, x_len):
        # x_len must be sorted in descending order here
        # (or pass enforce_sorted=False to pack_padded_sequence)
        x = nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=True)
        out1, (h_n, c_n) = self.lstm1(x)
        x, lengths = nn.utils.rnn.pad_packed_sequence(out1, batch_first=True)
        # x: (batch_size, sentence_len, 2 * hidden_size) is pooled into
        # (batch_size, 2 * hidden_size) plus (batch_size, sentence_len) scores
        x, att1 = self.atten1(x, lengths)

        # outer product of the pooled vector and the attention scores:
        # (batch, 2*hidden, 1) x (batch, 1, sentence_len) -> (batch, 2*hidden, sentence_len)
        tmp1 = torch.bmm(x.unsqueeze(2), att1.unsqueeze(1))
        # transpose to (batch_size, sentence_len, 2 * hidden_size)
        tmpp1 = tmp1.transpose(1, 2)
        return tmpp1
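
For completeness, this is how I call the module, just a minimal sketch with made-up sizes to show the shapes I describe below (the numbers are arbitrary):

# made-up sizes, only for a quick shape check
batch_size, sentence_len, emb_dim, hidden = 4, 10, 50, 64
model = MyLSTM(input_size=emb_dim, hidden_size=hidden, num_layers=1,
               batch_first=True, bidirectional=True, dropout=0.0)
dummy = torch.randn(batch_size, sentence_len, emb_dim)
lengths = [10, 9, 7, 5]  # sorted in descending order for pack_padded_sequence
out = model(dummy, lengths)
print(out.shape)  # torch.Size([4, 10, 128]), i.e. (batch_size, sentence_len, 2 * hidden)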

What I get after the attention are two tensors of shape (batch_size, n_lstm_unit) and (batch_size, sentence_len), respectively. What I need is a single tensor of shape (batch_size, sentence_len, n_lstm_unit).
As suggested by a forum user, what I could do is something like this:

tmp1 = torch.bmm(x.unsqueeze(2), att1.unsqueeze(1))
tmpp1 = tmp1.transpose(1, 2)
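
If I read the shapes correctly, that snippet is an outer product per sentence (again with made-up sizes; the shape comments are my own reading of it):

import torch

x = torch.randn(4, 128)                               # (batch_size, n_lstm_unit), pooled representation
att1 = torch.softmax(torch.randn(4, 10), dim=-1)      # (batch_size, sentence_len), attention scores
tmp1 = torch.bmm(x.unsqueeze(2), att1.unsqueeze(1))   # (4, 128, 1) x (4, 1, 10) -> (4, 128, 10)
tmpp1 = tmp1.transpose(1, 2)                          # (4, 10, 128), i.e. (batch_size, sentence_len, n_lstm_unit)
print(tmpp1.shape)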

The problem is that, in this way, the performance of my parser seems to deteriorate after adding the attention.
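
For context, this is what that operation actually computes (reusing x, att1 and tmpp1 from the snippet above): every timestep of the result is the same pooled vector, just rescaled by its attention score.

# tmpp1[i, t] == att1[i, t] * x[i] for every i and t
print(torch.allclose(tmpp1, att1.unsqueeze(-1) * x.unsqueeze(1)))  # prints True
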
I wanted to know whether the performance hit could be caused by this torch.bmm(x.unsqueeze(2), att1.unsqueeze(1)) operation used to obtain a single tensor of that shape, or whether it is a problem with the attention implementation itself. Thanks a lot to everyone.