I am trying to use LSTMCell to process variable-length sequences and obtain multiple predictions by adding a linear layer after it.
I took inspiration from this thread: How to obtain memory states from pack padded sequence - #2 by Fawaz_Sammani. Here is what I do:
import torch
from torch import nn
# Toy batch: four zero-padded token sequences plus their true lengths.
sequences = torch.tensor(
    [[1, 2, 0, 0, 0, 0],
     [3, 4, 5, 0, 0, 0],
     [5, 6, 0, 0, 0, 0],
     [8, 9, 10, 11, 12, 0]],
    dtype=torch.long,
)
seq_lengths = torch.tensor([2, 3, 2, 5], dtype=torch.long)
class PackedMemory(nn.Module):
    """Run an LSTMCell step-by-step over variable-length (padded) sequences,
    keeping per-step hidden/memory states and emitting predictions from
    linear heads at fixed time steps.

    Args:
        vocab_size: size of the embedding vocabulary.
        embed_dim: embedding dimension (LSTMCell input size).
        decoder_dim: LSTMCell hidden size and linear-head input size.
    """

    def __init__(self, vocab_size, embed_dim, decoder_dim):
        super(PackedMemory, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTMCell(embed_dim, decoder_dim)
        self.decoder_dim = decoder_dim
        # BUG FIX: the original used a plain Python list, which hides the
        # Linear layers from nn.Module — their parameters were never
        # registered, so the optimizer never updated them (this, not the
        # TODO branch, is why the learned results looked wrong).
        # nn.ModuleList registers them properly.
        self.linears = nn.ModuleList([nn.Linear(decoder_dim, 10),
                                      nn.Linear(decoder_dim, 10)])

    def forward(self, seq, seq_len):
        """Return a list of (batch, 10) predictions taken at t = 0, 5, 10.

        Args:
            seq: (batch, max_pad_len) LongTensor of token ids, 0-padded.
            seq_len: (batch,) LongTensor of true sequence lengths.
        """
        batch_size = seq.size(0)
        device = seq.device  # keep buffers on the module's input device
        # Sort by length (descending) so the sequences still active at
        # step t always occupy a prefix of the batch.
        sorted_lengths, sort_indices = torch.sort(seq_len, descending=True)
        sorted_sequences = seq[sort_indices]
        max_len = int(sorted_lengths.max().item())
        hidden_states = torch.zeros(batch_size, max_len, self.decoder_dim,
                                    device=device)
        memory_states = torch.zeros(batch_size, max_len, self.decoder_dim,
                                    device=device)
        # Stores each sequence's last valid hidden state (in sorted order).
        final_hidden = torch.zeros(batch_size, self.decoder_dim, device=device)
        embeddings = self.embedding(sorted_sequences)
        h = torch.zeros(batch_size, self.decoder_dim, device=device)
        c = torch.zeros(batch_size, self.decoder_dim, device=device)
        res = []
        for t in range(11):
            if t < max_len:
                # Number of sequences longer than t; thanks to the sort
                # these are exactly the first batch_size_t rows.
                batch_size_t = int((sorted_lengths > t).sum().item())
                h, c = self.lstm(embeddings[:batch_size_t, t, :],
                                 (h[:batch_size_t], c[:batch_size_t]))
                hidden_states[:batch_size_t, t, :] = h
                memory_states[:batch_size_t, t, :] = c
                # A sequence's final hidden state is the h of its last
                # active step; shorter sequences simply keep the value
                # written when they were last active, so no work is
                # needed for finished sequences (the original TODO).
                final_hidden[:batch_size_t] = h.clone()
            if t % 5 == 0:
                # Heads fire at t = 0, 5 (linears[0]) and t = 10 (linears[1]).
                res.append(self.linears[t // 10](final_hidden))
        return res
# Build the model and run a single forward pass over the toy batch.
model = PackedMemory(13, 512, 512)
res = model(sequences, seq_lengths)
print(res)
This code runs without errors, but the learned results are not what I expected.
I suspect the error may be in the line marked TODO. Does anyone have any suggestions?