Still, it's not clear to me how to apply attention to the LSTM outputs. For now, I have it like this:
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self, hparams):
        super(Encoder, self).__init__()
        # Stack of residual 1-D convolutions over the embedded inputs
        # (ResidualBlock1d is defined elsewhere in my code).
        self.conv = ResidualBlock1d(in_channels=hparams.encoder_embedding_dim,
                                    out_channels=hparams.encoder_embedding_dim,
                                    kernel_size=hparams.encoder_kernel_size,
                                    activation=hparams.activation,
                                    normtype=hparams.normtype,
                                    num_layers=hparams.encoder_n_convolutions - 1)
        # Bidirectional LSTM with half the hidden size per direction, so the
        # concatenated output stays at encoder_embedding_dim.
        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            hparams.encoder_embedding_dim // 2, 1,
                            batch_first=False, bidirectional=True)
        # Self-attention over the encoder sequence; 8 heads, expects
        # (time, batch, embed) inputs.
        self.attn = nn.MultiheadAttention(hparams.encoder_embedding_dim, 8)
    def attn_pad_mask(self, lengths):
        # Boolean mask of shape (batch, max_len); True marks padded
        # positions so MultiheadAttention ignores them as keys.
        max_len = torch.max(lengths).item()
        mask = (torch.arange(max_len, device=lengths.device)[None, :]
                >= lengths[:, None])
        return mask
    def forward(self, x, input_lengths):
        # x: (batch, time, encoder_embedding_dim)
        x = x.transpose(1, 2)                   # -> (batch, channels, time)
        x = self.conv(x)
        x = x.transpose(1, 2).transpose(0, 1)   # -> (time, batch, channels)
        # pack_padded_sequence wants the lengths on the CPU
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths.cpu(), batch_first=False)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=False)
        # Self-attention: queries, keys and values are all the LSTM
        # outputs; padded time steps are masked out as keys.
        attn_mask = self.attn_pad_mask(input_lengths)
        outputs, _ = self.attn(
            outputs, outputs, outputs,
            key_padding_mask=attn_mask, need_weights=False)
        outputs = outputs.transpose(0, 1)       # -> (batch, time, embed)
        return outputs
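
As a sanity check, this is roughly how I exercise it. The hparams values and the ResidualBlock1d stand-in below are just assumptions for the test (the real block lives elsewhere), not the actual implementation:

from types import SimpleNamespace

class ResidualBlock1d(nn.Module):
    # Hypothetical stand-in: one length-preserving Conv1d with a
    # residual connection, only so the Encoder can be instantiated.
    def __init__(self, in_channels, out_channels, kernel_size,
                 activation=None, normtype=None, num_layers=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=kernel_size // 2)

    def forward(self, x):
        return x + self.conv(x)

hparams = SimpleNamespace(encoder_embedding_dim=512,
                          encoder_kernel_size=5,
                          activation='relu', normtype='batch',
                          encoder_n_convolutions=3)
enc = Encoder(hparams)
x = torch.randn(4, 37, 512)               # (batch, time, embed)
lengths = torch.tensor([37, 30, 22, 10])  # sorted descending, as
                                          # pack_padded_sequence expects
out = enc(x, lengths)
print(out.shape)                          # torch.Size([4, 37, 512])

Note that key_padding_mask uses True to mean "ignore this key", which is why attn_pad_mask marks the padded positions, not the valid ones.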