Selecting elements along a dimension from a list of indexes

I’m implementing an LSTM on audio clips of different lengths. After padding/packing, the output of the LSTM is:

a.shape = torch.Size([16, 1580, 201])

with (batch, padded sequence, feature). I also have a list of the actual lengths of the sequences:

lengths = [1580, 959, 896, 881, 881, 881, 881, 881, 881, 881, 881, 881, 881, 335, 254, 219].

What I would like to do is, for every element in the batch, select the output at the last valid (non-padded) step of its sequence, and end up with a tensor of shape:

torch.Size([16, 201])

(regardless of each example’s sequence length). So far I’ve been using:
torch.cat([a[i:i+1][:,ind-1] for i,ind in enumerate(lengths)], dim=0)

but I was wondering whether there’s a proper PyTorch function for this use case.

Thanks.

The following indexing should work:

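import torch

x = torch.randn(16, 1580, 201)  # (batch, padded sequence, feature)
idx = torch.tensor(
    [1580, 959, 896, 881, 881, 881, 881, 881, 881, 881, 881, 881, 881, 335, 254, 219]
)
idx = idx - 1  # convert lengths to 0-based indices of the last valid step
# Pair each batch row i with time step idx[i]; the result has shape (16, 201)
y = x[torch.arange(x.size(0)), idx]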
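If a named function is preferred over advanced indexing, `torch.gather` can do the same selection. The sketch below is an illustration, not part of the original answer: it builds an index of shape (batch, 1, feature) holding `lengths - 1` along dim 1, and checks that advanced indexing, `gather`, and the loop from the question all return the same (16, 201) tensor. The random `x` stands in for the real LSTM output.

```python
import torch

# Stand-in for the LSTM output: (batch, padded sequence, feature)
x = torch.randn(16, 1580, 201)
lengths = torch.tensor(
    [1580, 959, 896, 881, 881, 881, 881, 881, 881, 881, 881, 881, 881, 335, 254, 219]
)

# Advanced indexing, as in the answer
y_index = x[torch.arange(x.size(0)), lengths - 1]

# torch.gather needs an index with as many dims as x:
# shape (batch, 1, feature), every entry in row i equal to lengths[i] - 1
gather_idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, x.size(2))
y_gather = x.gather(1, gather_idx).squeeze(1)  # (16, 201)

# The loop from the question, for comparison
y_loop = torch.cat(
    [x[i:i + 1][:, ind - 1] for i, ind in enumerate(lengths.tolist())], dim=0
)

print(y_index.shape)  # torch.Size([16, 201])
print(torch.equal(y_index, y_gather), torch.equal(y_index, y_loop))  # True True
```

Both vectorized versions avoid the Python-level loop. The indexing form is shorter, while `gather` makes the "one position per row along dim 1" intent explicit.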

Awesome, that’s exactly what I needed, thanks!
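A related shortcut, offered as a hedged sketch rather than something confirmed in this thread: since the sequences were packed, if `a` comes straight from a single-layer, unidirectional `nn.LSTM`, the returned `h_n` should already hold each sequence’s hidden state at its last valid step, so no indexing is needed. The sizes below are small and hypothetical. For bidirectional or multi-layer LSTMs `h_n` is laid out differently, and the backward direction’s final state corresponds to the first time step, so this equivalence does not carry over directly.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical small sizes: batch 4, max length 10, input 8, hidden 5
batch, max_len, n_in, n_hidden = 4, 10, 8, 5
lstm = nn.LSTM(n_in, n_hidden, batch_first=True)  # single layer, unidirectional

inputs = torch.randn(batch, max_len, n_in)
lengths = torch.tensor([10, 7, 3, 6])  # CPU tensor of true lengths, unsorted

# enforce_sorted=False lets PyTorch sort internally and restore the original order
packed = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Padded output: (batch, max_len, hidden); positions past each length are zeros
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Last valid step selected by indexing, as in the answer
last_by_index = out[torch.arange(batch), lengths - 1]

# h_n: (num_layers * num_directions, batch, hidden); here that is (1, 4, 5)
last_by_hn = h_n[-1]

print(torch.allclose(last_by_index, last_by_hn))  # True
```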
