Hi, I have a project where my inputs are of shape (batch_size, max_sentences, max_words, embedding_dim).
I want to take the last hidden state as the encoded representation of each sentence, ending up with a matrix of shape (batch_size, max_sentences, hidden_size). How can I encode this using an LSTM?
Can it be something like this:
class LSTMEncoder(nn.Module):
    """Encode each sentence of a document with an LSTMCell.

    Input is a batch of documents given as token ids of shape
    (batch_size, max_sentences, max_words); each sentence is run through
    the LSTM word by word and its final hidden state is kept, producing
    a tensor of shape (batch_size, max_sentences, hidden_size).
    """

    def __init__(self, vocab_size, emb_dim, hidden_size, embed=None):
        """
        vocab_size:  size of the token vocabulary
        emb_dim:     dimensionality of the word embeddings
        hidden_size: dimensionality of the LSTM hidden state
        embed:       optional pre-built embedding module; if None, a fresh
                     nn.Embedding(vocab_size, emb_dim) is created
                     (the original accepted this argument but ignored it)
        """
        super(LSTMEncoder, self).__init__()
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim          # was read but never assigned before
        self.hidden_size = hidden_size  # `enc_hid_dim` was an undefined name
        self.embed = embed if embed is not None else nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTMCell(emb_dim, hidden_size)

    def forward(self, seq):
        """
        seq: (batch_size, max_sentences, max_words) LongTensor of token ids
        returns: (batch_size, max_sentences, hidden_size) — the last hidden
                 state of the LSTM for each sentence
        """
        batch_size, max_sentences, max_words = seq.size()
        device = seq.device  # derive from input; `device` was an undefined free name
        final_hidden_states = torch.zeros(
            batch_size, max_sentences, self.hidden_size, device=device
        )
        for step in range(max_sentences):
            # Fresh (h, c) per sentence so each sentence is encoded
            # independently; the original carried state across sentences.
            h = torch.zeros(batch_size, self.hidden_size, device=device)
            c = torch.zeros(batch_size, self.hidden_size, device=device)
            embedding = self.embed(seq[:, step])  # (batch_size, max_words, emb_dim)
            for i in range(max_words):
                h, c = self.lstm(embedding[:, i], (h, c))
            # Store only the LAST hidden state of the sentence; the original
            # `final_hidden_states[:, step, i] = h` tried to write a
            # (batch, hidden) tensor into a (batch,) slot and would error.
            final_hidden_states[:, step] = h
        return final_hidden_states