How to select uneven indices along a dimension in a batch and apply pooling?


I have to select certain hidden states from an LSTM's output, where the indices differ for each sample, and then apply pooling over them. A simple way is to write a for loop that selects the indices for each sample in the batch, but I'd like to use PyTorch's API to minimize Python overhead. Is there a way to do this?
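One loop-free approach is to turn the ragged per-sample indices into a boolean mask of shape `(B, T)`. You can then do masked mean and max pooling with broadcasting, so the heavy work runs as batched tensor ops. This is a minimal sketch, not an established recipe, and it makes some assumptions:

- The LSTM is built with `batch_first=True`, so its output has shape `(B, T, H)`.
- The indices arrive as Python lists.
- `indices_to_mask` and `masked_pool` are hypothetical helper names.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def indices_to_mask(index_lists, seq_len):
    """Hypothetical helper: ragged per-sample index lists -> (B, T) bool mask.

    Lists are padded with the out-of-range value `seq_len`, scattered into a
    (B, T + 1) buffer, and the extra dummy column is dropped at the end.
    """
    idx = pad_sequence(
        [torch.as_tensor(ix, dtype=torch.long) for ix in index_lists],
        batch_first=True,
        padding_value=seq_len,
    )  # (B, K), padded with seq_len
    buf = torch.zeros(idx.size(0), seq_len + 1, dtype=torch.long, device=idx.device)
    buf.scatter_(1, idx, torch.ones_like(idx))  # mark selected time steps with 1
    return buf[:, :seq_len].bool()  # drop the dummy padding column


def masked_pool(hidden, mask):
    """Hypothetical helper: mean and max over selected time steps.

    hidden: (B, T, H) LSTM output (batch_first=True assumed)
    mask:   (B, T) bool, True where a time step is selected
    Returns (mean, max), each of shape (B, H).
    """
    m = mask.unsqueeze(-1)  # (B, T, 1), broadcasts over the hidden dim
    counts = m.sum(dim=1).clamp(min=1)  # (B, 1); clamp avoids division by zero
    mean = (hidden * m).sum(dim=1) / counts
    # Unselected positions become -inf so they can never be the maximum.
    mx = hidden.masked_fill(~m, float("-inf")).max(dim=1).values
    return mean, mx


# Usage: a stand-in for LSTM output with B=2, T=4, H=3.
hidden = torch.arange(24, dtype=torch.float32).view(2, 4, 3)
mask = indices_to_mask([[0, 2], [1, 2, 3]], seq_len=4)
mean, mx = masked_pool(hidden, mask)
print(mean)  # rows: [3., 4., 5.] and [18., 19., 20.]
print(mx)    # rows: [6., 7., 8.] and [21., 22., 23.]
```

In this sketch, a sample with no selected indices gets a zero mean and a `-inf` max, so handle that case explicitly if it can occur. If your indices already come as a padded `LongTensor`, you can pass it directly to `scatter_` and skip the list comprehension. If you need the selected hidden states themselves rather than a pooled summary, `torch.gather` on a padded index tensor is an alternative.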