I’m implementing an LSTM on audio clips of different lengths. After padding and packing the inputs, the LSTM output has shape:
a.shape = torch.Size([16, 1580, 201])
i.e. (batch, padded sequence length, features). I also have a list of the actual lengths of the sequences:
lengths = [1580, 959, 896, 881, 881, 881, 881, 881, 881, 881, 881, 881, 881, 335, 254, 219]
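For context, here is a minimal sketch of a padding/packing pipeline like the one described above. It assumes a unidirectional, single-layer nn.LSTM with batch_first=True and uses toy sizes (the names x, n_in and hidden are hypothetical stand-ins, not the real 201-feature model). It also checks that, for packed input, the h_n returned by the LSTM already holds each sequence's output at its last valid time step. For a bidirectional LSTM this does not hold for the backward direction, whose final state corresponds to t = 0.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch, max_len, n_in, hidden = 4, 7, 3, 5        # toy sizes (assumed)
x = torch.randn(batch, max_len, n_in)            # padded input batch
lengths = torch.tensor([7, 5, 3, 1])             # sorted descending, like the question

lstm = nn.LSTM(n_in, hidden, batch_first=True)   # unidirectional, 1 layer
packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # (batch, max_len, hidden)

# h_n has shape (num_layers * num_directions, batch, hidden);
# taking the last layer gives one vector per sequence: (batch, hidden).
last_from_h = h_n[-1]

# The same values, read from the padded output at position lengths - 1.
last_from_out = out[torch.arange(batch), lengths - 1]
```

With more than one layer, h_n[-1] still refers to the top layer, which is the one whose outputs appear in out.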
What I would like to do is, for every element in the batch, select the output at the last valid (non-padded) time step of its sequence, ending up with a tensor of shape:
torch.Size([16, 201])
(independent of each example's sequence length). So far I’ve been using:
torch.cat([a[i:i+1][:,ind-1] for i,ind in enumerate(lengths)], dim=0)
but I was wondering whether there’s a proper PyTorch function for this use case?
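A sketch of two vectorized alternatives to the loop, using toy tensors in place of the real a and lengths (shapes reduced for speed). Both assume lengths is converted to a LongTensor on the same device as a. The first uses advanced indexing with a pair of index tensors; the second uses torch.gather with the index expanded over the feature dimension. Each is compared against the original list comprehension.

```python
import torch

# Toy stand-ins: a is (batch, padded_seq_len, features), lengths are true lengths.
torch.manual_seed(0)
a = torch.randn(4, 7, 5)
lengths = [7, 5, 3, 1]

# Index of the last valid time step for each sequence, as a LongTensor.
idx = torch.as_tensor(lengths, dtype=torch.long, device=a.device) - 1

# 1) Advanced indexing: row i is paired with time step idx[i].
last = a[torch.arange(a.size(0), device=a.device), idx]   # (batch, features)

# 2) torch.gather along dim 1: the index must have the same number of dims as a.
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, a.size(2))  # (batch, 1, features)
last_gather = a.gather(1, gather_idx).squeeze(1)          # (batch, features)

# The loop from the question, for comparison.
last_loop = torch.cat([a[i:i+1][:, ind-1] for i, ind in enumerate(lengths)], dim=0)

print(last.shape)  # torch.Size([4, 5])
```

Advanced indexing is usually the more readable of the two; gather is handy when the selection has to be written as a single differentiable tensor op along one dimension.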
Thanks.