I would like to average the outputs of GRU/LSTM. The input sequences have different lengths, so I use packing. With the following simple code, what is the best/efficient way to get the outputs (output of the RNN, not the hidden states h) and take their mean? Either from the packed output or from the padded output. I can use a loop and the sequence lengths to achieve that, but that would be very slow. I am in search of an efficient matrix solution.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class RNNText(nn.Module):
    def __init__(self, vocab_size, word_dim=512, embed_size=512, num_layers=1):
        super(RNNText, self).__init__()
        self.embed_size = embed_size
        self.embed = nn.Embedding(vocab_size, word_dim)
        self.rnn = nn.GRU(word_dim, embed_size, num_layers, batch_first=True)
        self.fc = nn.Linear(embed_size, embed_size)

    def forward(self, x, lengths):
        x = self.embed(x)
        # lengths must be sorted in decreasing order, or pass enforce_sorted=False
        packed = pack_padded_sequence(x, lengths, batch_first=True)
        # Forward propagate RNN
        out_packed, h = self.rnn(packed)
        # padded, _ = pad_packed_sequence(out_packed, batch_first=True)
        out = torch.mean(???)
        out = self.fc(out)
        out = F.normalize(out)
        return out
Thanks Chris.
This was the first thing I thought about, but I found it somewhat limiting and more complicated.
I am trying to find a simpler solution, if there is one. I have not been able to think of one yet.
Transformer models pad to a fixed length since they have no other choice, but I feel this adversely affects performance, especially when the sequence lengths vary a lot.
You could sum the padded tensor, and then divide it by the lengths of the sequences.
I suppose padded is shaped [batch_size, seq_len, embed_size]
so padded.sum(dim=1) is shaped [batch_size, embed_size]
and lengths is shaped [batch_size]
out = padded.sum(dim=1).div(lengths.float().unsqueeze(dim=1))
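A minimal sketch of this sum-then-divide trick, with made-up shapes (batch of 3 sequences padded to length 4, hidden size 2) and explicitly zeroed padding to stand in for what `pad_packed_sequence` would return; the loop version is included only to check that the vectorized result matches:

```python
import torch

torch.manual_seed(0)
lengths = torch.tensor([4, 2, 3])
padded = torch.randn(3, 4, 2)

# Zero out the padded timesteps, as pad_packed_sequence would produce.
mask = torch.arange(4).unsqueeze(0) < lengths.unsqueeze(1)  # [3, 4] bool
padded = padded * mask.unsqueeze(-1).float()

# Vectorized mean over the valid (non-padded) timesteps.
out = padded.sum(dim=1).div(lengths.float().unsqueeze(dim=1))  # [3, 2]

# Slow per-sequence loop, for comparison only.
expected = torch.stack([padded[i, :lengths[i]].mean(dim=0) for i in range(3)])
```

Since the padded positions are exactly zero, summing over the full time axis and dividing by the true lengths gives the same result as the loop, with no Python-level iteration.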
But it is zero-padding, right? So the padding doesn’t contribute anything to the sum, and since you divide by the true lengths, it ends up the same as averaging over only the non-padded part.
I don’t see any other way, except using a loop.
What if we use an embedding layer? We can’t be sure that the embedding for the padding token would always be zero. And what if we want to apply another operation, such as a product (like a geometric mean), where zeros would ruin everything?
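Both concerns can be handled by building a boolean mask from the lengths and using it in the reduction, rather than relying on the padding values being zero. A sketch with made-up shapes (the positive values are only there so the geometric mean's log is defined):

```python
import torch

torch.manual_seed(0)
lengths = torch.tensor([4, 2, 3])
out_rnn = torch.randn(3, 4, 2).abs() + 0.1  # [batch, seq, hidden], all positive

# Valid-timestep mask built from lengths; padding content is irrelevant.
mask = torch.arange(out_rnn.size(1)).unsqueeze(0) < lengths.unsqueeze(1)  # [3, 4]
maskf = mask.unsqueeze(-1).float()  # [3, 4, 1], broadcasts over hidden dim

# Arithmetic mean: zero out padding explicitly before summing.
mean = (out_rnn * maskf).sum(dim=1) / lengths.float().unsqueeze(1)

# Geometric mean: average the logs at valid positions only, then exponentiate.
geo = torch.exp((out_rnn.log() * maskf).sum(dim=1) / lengths.float().unsqueeze(1))

# Loop versions, for comparison only.
expected_mean = torch.stack([out_rnn[i, :lengths[i]].mean(dim=0) for i in range(3)])
expected_geo = torch.stack(
    [out_rnn[i, :lengths[i]].log().mean(dim=0).exp() for i in range(3)])
```

Because the mask is derived from the lengths, this works even if the padding embedding is nonzero, and the same pattern extends to other reductions (max via `masked_fill` with `-inf`, etc.).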