Hello everybody,
I’ve read tons of posts and questions around the web on how best to use self-attention with a BiLSTM. I am now a little confused and need some clarification before I continue studying.
I created a model with a BiLSTM like this:
# Recurrent encoder; the trailing comments give the values used in the post.
lstm_options = dict(
    input_size=input_size,  # 140 embedding dim
    hidden_size=hidden_size,  # 400
    num_layers=num_layers,  # 2
    batch_first=batch_first,  # True
    bidirectional=bidirectional,  # True
    dropout=dropout,
)
self.lstm = nn.LSTM(**lstm_options)
and I have an attention mechanism done like this:
class SelfAttention(nn.Module):
    """Score every timestep with one learned vector and reweight the states.

    Each hidden state is projected onto ``attention_weights`` (dot product),
    passed through a non-linearity and normalised with a softmax over the
    time axis. Padded timesteps receive a weight of exactly zero.

    Args:
        attention_size: Size of the last dimension of the inputs (it must
            match for the dot product in ``forward``).
        batch_first: Stored on the module but not consulted by ``forward``,
            which always treats its inputs as ``(batch, len, hidden_size)``.
        non_linearity: ``"relu"`` selects ``nn.ReLU``; any other value
            selects ``nn.Tanh``.
    """

    def __init__(self, attention_size, batch_first=False, non_linearity="tanh"):
        super(SelfAttention, self).__init__()
        self.batch_first = batch_first
        self.attention_weights = nn.Parameter(
            torch.FloatTensor(attention_size), requires_grad=True
        )
        self.softmax = nn.Softmax(dim=-1)
        if non_linearity == "relu":
            self.non_linearity = nn.ReLU()
        else:
            self.non_linearity = nn.Tanh()
        # The parameter above is allocated uninitialised; give it small
        # symmetric starting values.
        nn.init.uniform_(self.attention_weights.data, -0.005, 0.005)

    def get_mask(self, attentions, lengths):
        """Construct a 0/1 mask for padded timesteps, based on lengths.

        Args:
            attentions: Tensor of shape ``(batch, len)``; only its shape,
                dtype and device are used.
            lengths: Number of valid timesteps per sequence (a tensor or a
                sequence of ints).

        Returns:
            Tensor shaped like ``attentions`` holding 1 where
            ``t < lengths[i]`` and 0 elsewhere.
        """
        # Build the mask on the device of the scores so this works on CPU
        # and GPU alike, without a Python loop over the batch.
        lengths = torch.as_tensor(lengths, device=attentions.device)
        positions = torch.arange(attentions.size(-1), device=attentions.device)
        # Compare every position with every length, so padding is masked even
        # when the tensor is wider than the longest sequence in the batch.
        mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
        return mask.to(attentions.dtype)

    def forward(self, inputs, lengths):
        """Reweight ``inputs`` by their attention scores.

        Args:
            inputs: 3D tensor ``(batch, len, hidden_size)``.
            lengths: Number of valid timesteps per sequence.

        Returns:
            A tuple ``(representations, scores)``. ``representations`` has
            the same shape as ``inputs``, with each hidden state scaled by
            its score; the time axis is NOT summed, so callers that need one
            vector per sequence must sum over dim 1 themselves. ``scores``
            is ``(batch, len)``; each row sums to 1 over its valid
            timesteps and is 0 on padding.
        """
        ##################################################################
        # STEP 1 - dot product of the attention vector and each hidden state
        ##################################################################
        # inputs is a 3D Tensor: batch, len, hidden_size
        # scores is a 2D Tensor: batch, len
        scores = self.non_linearity(inputs.matmul(self.attention_weights))
        scores = self.softmax(scores)

        ##################################################################
        # STEP 2 - masking
        ##################################################################
        # zero out the padded timesteps ...
        mask = self.get_mask(scores, lengths)
        masked_scores = scores * mask
        # ... and re-normalise so the remaining scores sum to 1 again.
        # The clamp keeps an all-padding row at 0 instead of 0/0 = NaN.
        row_sums = masked_scores.sum(-1, keepdim=True)
        smallest = torch.finfo(masked_scores.dtype).tiny
        scores = masked_scores / row_sums.clamp(min=smallest)

        ##################################################################
        # STEP 3 - scale each hidden state by its attention score
        ##################################################################
        representations = inputs * scores.unsqueeze(-1)
        return representations, scores
In the forward function of my model, is it correct to use the output of the BiLSTM as the input to the attention layer, or should I use h_n or something else?
def forward(self, x, x_len):
    """Run the padded batch through the LSTM and apply attention to its outputs.

    Args:
        x: Padded input batch. It is packed with ``batch_first=True``, so it
            is presumably ``(batch, len, features)`` - confirm with the caller.
        x_len: Valid length of every sequence in ``x``.
            NOTE(review): ``pack_padded_sequence`` is called with its default
            arguments, which in current PyTorch expect the lengths sorted in
            decreasing order - confirm for the installed version.

    Returns:
        The first value returned by ``self.atten1``.
    """
    packed = nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=True)
    # Attention scores each timestep, so it is fed the per-timestep outputs;
    # the final (h_n, c_n) states are not needed here and are discarded.
    packed_out, _ = self.lstm1(packed)
    # Unpacked with batch_first=True, so the batch dimension comes first.
    states, lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
    # The second return value (the attention scores) is not used here.
    attended, _ = self.atten1(states, lengths)
    return attended
Thank you so much!