I try to use LSTMCell to produce results for variable-length sequences, and get multiple predictions by adding a linear layer after it,
I take inspiration from this codebase How to obtain memory states from pack padded sequence - #2 by Fawaz_Sammani, and what I do is as follows,
import torch from torch import nn sequences = torch.LongTensor([[1, 2, 0, 0, 0, 0], [3, 4, 5, 0, 0, 0], [5, 6, 0, 0, 0, 0], [8, 9, 10, 11, 12, 0]]) seq_lengths = torch.LongTensor([2, 3, 2, 5]) class PackedMemory(nn.Module): def __init__(self, vocab_size, embed_dim, decoder_dim): super(PackedMemory, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_dim) self.lstm = nn.LSTMCell(embed_dim, decoder_dim) self.decoder_dim = decoder_dim self.linears = [nn.Linear(decoder_dim, 10), nn.Linear(decoder_dim, 10)] def forward(self, seq, seq_len): batch_size = seq.size(0) sorted_lengths, sort_indices = torch.sort(seq_len, descending=True) sorted_sequences = seq[sort_indices] hidden_states = torch.zeros(batch_size, max(sorted_lengths), self.decoder_dim) memory_states = torch.zeros(batch_size, max(sorted_lengths), self.decoder_dim) final_hidden = torch.zeros(batch_size, self.decoder_dim) # 这个用于存储 hidden embeddings = self.embedding(sorted_sequences) h,c = [torch.zeros(batch_size, self.decoder_dim), torch.zeros(batch_size, self.decoder_dim)] res =  for t in range(11): if t >= max(sorted_lengths): # TODO for some short sequence pass else: batch_size_t = sum([l > t for l in sorted_lengths]) print(batch_size_t) h, c = self.lstm(embeddings[:batch_size_t, t, :], (h[:batch_size_t], c[:batch_size_t])) hidden_states[:batch_size_t, t, :] = h memory_states[:batch_size_t, t, :] = c final_hidden[:batch_size_t] = h.clone() if t%5 == 0: res.append(self.linears[t//10](final_hidden)) return res packed_memory = PackedMemory(13,512,512) res = packed_memory(sequences, seq_lengths) print(res)
This code is runnable, but it seems the learned results are not as I imagined.
I guess the error may occur in the line marked as TODO. Does anyone have some suggestions?