Efficient way to slice tensors with multiple indices

Hi there,

I have a tensor with shape (N x T x D), representing embeddings of N sentences, and a tensor with shape (N x 2), representing the start and end indices of entities in the N sentences.

Is there an efficient way to slice out the embeddings of the entities and return a new tensor (N x L x D), where L is the length of the entities?


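You can construct a NumPy array containing all the indices that need to be included. The shape of this array should be N x L, where row i holds start_i, start_i + 1, ..., start_i + L - 1. Then index the tensor with this array along the T dimension, paired with a row-index array of shape N x 1 (for example np.arange(N)[:, None]) so that each row of indices selects from its own sentence and the result has shape N x L x D.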
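A minimal NumPy sketch of that approach is shown below. It assumes the end index is exclusive (as in Python slicing) and that every span is non-empty. Because entities in different sentences may have different lengths, it takes L as the longest span, clamps the extra positions of shorter spans back into range, zeroes them, and returns a boolean mask marking the real positions. The function name `gather_spans` is hypothetical.

```python
import numpy as np

def gather_spans(emb, spans):
    """Gather entity embeddings from emb (N x T x D) using spans (N x 2).

    Assumes spans[i] = (start, end) with an exclusive end and end > start.
    Returns (out, mask): out has shape (N x L x D), L = longest span;
    positions past a shorter span's end are zero and marked False in mask.
    """
    starts = spans[:, 0]
    ends = spans[:, 1]
    lengths = ends - starts
    L = int(lengths.max())
    offsets = np.arange(L)                          # (L,)
    # N x L array of time indices: start, start + 1, ..., start + L - 1
    idx = starts[:, None] + offsets[None, :]
    # True where the position lies inside the entity span
    mask = offsets[None, :] < lengths[:, None]
    # Clamp padded positions to the span's last valid index so indexing stays in range
    idx = np.minimum(idx, (ends - 1)[:, None])
    # Row indices (N x 1) broadcast against idx (N x L) -> result is N x L x D
    rows = np.arange(emb.shape[0])[:, None]
    out = emb[rows, idx]
    # Zero out the padded positions
    out = np.where(mask[..., None], out, 0)
    return out, mask

# Usage: 3 sentences, 6 tokens each, 4-dimensional embeddings
N, T, D = 3, 6, 4
emb = np.arange(N * T * D, dtype=float).reshape(N, T, D)
spans = np.array([[1, 3], [0, 3], [4, 6]])  # entity lengths 2, 3, 2
out, mask = gather_spans(emb, spans)
print(out.shape)  # (3, 3, 4)
```

If all entities have the same length, the mask is all True and the padding step has no effect. The question does not say which framework the tensor comes from; if it is a PyTorch tensor, the same indexing expression emb[rows, idx] should work with integer index tensors built via torch.arange, with torch.minimum in place of np.minimum.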