Hello!
I am new to PyTorch and I am trying to implement a bidirectional LSTM model with input sequences of varying length. I want to mask the padded positions so that the padding does not influence the gradient computation. Is this the right way to proceed?
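My understanding is that pack_padded_sequence hands the LSTM only the real time steps, so the padding never enters the recurrence. Here is a minimal standalone sketch of that idea, with made-up sizes:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Two sequences of true lengths 3 and 2, padded to length 3 (made-up sizes)
batch = torch.randn(2, 3, 5)         # (batch, max_seq_len, features)
lengths = torch.tensor([3, 2])
lstm = nn.LSTM(input_size=5, hidden_size=4, batch_first=True, bidirectional=True)

packed = pack_padded_sequence(batch, lengths, batch_first=True, enforce_sorted=False)
out_packed, _ = lstm(packed)         # the recurrence never touches the padded step
out, out_lengths = pad_packed_sequence(out_packed, batch_first=True)
print(out.shape, out_lengths)        # torch.Size([2, 3, 8]) tensor([3, 2])

This is my model: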
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Bi_RNN(nn.Module):
    def __init__(self, input_dim1, input_dim2, hidden_dim, batch_size, output_dim, num_layers=2, rnn_type='LSTM'):
        super(Bi_RNN, self).__init__()
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        # Initial linear layers (defined here but not yet used in forward)
        self.init_linear1 = nn.Linear(self.input_dim1, self.input_dim1)
        self.init_linear2 = nn.Linear(self.input_dim2, self.input_dim2)
        # Define the LSTM layers; getattr is safer than eval for looking up the RNN class
        rnn_cls = getattr(nn, rnn_type)
        self.lstm1 = rnn_cls(self.input_dim1, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True)
        self.lstm2 = rnn_cls(self.input_dim2, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True)
        # Each bidirectional LSTM outputs hidden_dim * 2 features, so the concatenation is hidden_dim * 4
        self.lstm3 = rnn_cls(self.hidden_dim * 4, self.hidden_dim, self.num_layers, batch_first=True, bidirectional=True)
        # Define the output layer
        self.linear = nn.Linear(self.hidden_dim * 2, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=-1)  # softmax over the classes, not over the time axis
    def init_hidden(self):
        # Initial hidden/cell states; a bidirectional LSTM needs num_layers * 2 in the first dimension
        return (torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers * 2, self.batch_size, self.hidden_dim))
    def forward(self, input1, input2, x_lens):
        # Pack the padded batches so the LSTMs skip the padded time steps entirely
        x1_packed = pack_padded_sequence(input1, x_lens, batch_first=True, enforce_sorted=False)
        x2_packed = pack_padded_sequence(input2, x_lens, batch_first=True, enforce_sorted=False)
        # Forward pass through the first two LSTM layers
        lstm_out1, self.hidden1 = self.lstm1(x1_packed)
        lstm_out1_padded, output1_lengths = pad_packed_sequence(lstm_out1, batch_first=True)
        lstm_out2, self.hidden2 = self.lstm2(x2_packed)
        lstm_out2_padded, output2_lengths = pad_packed_sequence(lstm_out2, batch_first=True)
        # Concatenate the two feature streams along the feature axis
        lstm_out3 = torch.cat((lstm_out1_padded, lstm_out2_padded), 2)
        # Re-pack before the third LSTM, otherwise it would run over the padding again
        x3_packed = pack_padded_sequence(lstm_out3, output1_lengths, batch_first=True, enforce_sorted=False)
        lstm_out, self.hidden3 = self.lstm3(x3_packed)
        lstm_out_padded, _ = pad_packed_sequence(lstm_out, batch_first=True)
        y_pred = self.log_softmax(self.linear(lstm_out_padded))
        return y_pred
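For completeness, this is how I smoke-test the module; all the sizes here are made up for illustration:

model = Bi_RNN(input_dim1=10, input_dim2=6, hidden_dim=16, batch_size=4, output_dim=3)
x1 = torch.randn(4, 7, 10)           # (batch, max_seq_len, input_dim1)
x2 = torch.randn(4, 7, 6)            # (batch, max_seq_len, input_dim2)
x_lens = torch.tensor([7, 5, 4, 2])  # true lengths before padding
y = model(x1, x2, x_lens)
print(y.shape)                       # torch.Size([4, 7, 3])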