Is it possible to use torch.argsort output to index a tensor?

Assume I have a model that takes as input a sequence of vectors, i.e. [B,N,C] = input.shape, with B as the batch dimension, N as the sequence length, and C as the latent dimension.

Additionally, at some point in my model I get a tensor w of shape [B,N] that weights each element of my sequence, i.e. w.sum(dim=1) == [1., 1., 1., ...] of size [B,].
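
For example (just an illustration, not the exact code in my model), such a w could come from a softmax over per-element scores, so that each row sums to one:

import torch

B, N = 4, 7
scores = torch.rand(B, N)           # hypothetical per-element scores
w = torch.softmax(scores, dim=1)    # normalize each row, shape [B, N]
print(w.sum(dim=1))                 # a tensor of ones of shape [B]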

The following code shows my attempt to get the top-k sequence elements of my input X that have the highest values in w:

X = torch.rand(4, 7, 128)                             # [B, N, C]
w = torch.rand(4, 7)                                  # [B, N]
w_sorted = torch.argsort(w, dim=1, descending=True)   # indices that sort each row of w
w_trunc = w_sorted[:, :5]                             # indices of the 5 largest weights per batch element

X_topk = torch.empty(X.shape[0], 5, 128)
for i in range(X.shape[0]):
    X_topk[i, :, :] = X[i, w_trunc[i, :], :]          # gather the top-5 sequence elements of batch i

Is there a better way to do this without the loop, by directly indexing multiple dimensions of X using w_trunc?

Hi Agent!

Yes, you can create appropriate index tensors and use PyTorch tensor
indexing to avoid the loop. (You can also gain a bit of efficiency
by using torch.topk(), which performs the typically cheaper partial
sort rather than the full sort used by torch.argsort().)

Here is an illustrative script:

import torch
print (torch.__version__)

_ = torch.manual_seed (2021)

B = 4
N = 7
C = 128
k = 5

X = torch.rand (B, N, C)
w = torch.rand (B, N)
w_sorted = torch.argsort (w, dim = 1, descending = True)
w_trunc = w_sorted[:, :k]

# method with loop
X_topk = torch.empty(X.shape[0], k, C)
for i in range(X.shape[0]):
    X_topk[i, :, :] = X[i, w_trunc[i, :], :]

# method with no-loop indexing and partial instead of full sort
i0 = torch.arange (B).unsqueeze (-1).unsqueeze (-1).expand (B, k, C)      # batch index of every output element
i1 = torch.topk (w, k, dim = 1).indices.unsqueeze (-1).expand (B, k, C)   # top-k sequence indices, broadcast across C
i2 = torch.arange (C).expand (B, k, C)                                    # channel index of every output element

X_topk_B = X[i0, i1, i2]   # X_topk_B[b, j, c] = X[i0[b, j, c], i1[b, j, c], i2[b, j, c]]

print ('X_topk_B.equal (X_topk):', X_topk_B.equal (X_topk))

And here is its output:

1.9.0
X_topk_B.equal (X_topk): True
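
As a footnote, here is a sketch (continuing with the tensors from the script above; the names X_topk_C and X_topk_D are just for illustration) of two equivalent ways to get the same result with less index bookkeeping: PyTorch's advanced indexing broadcasts its index tensors, so a column of batch indices together with the top-k indices is enough, and torch.gather() works as well:

b_idx = torch.arange (B).unsqueeze (1)         # shape [B, 1], broadcasts against top_idx
top_idx = torch.topk (w, k, dim = 1).indices   # shape [B, k]

X_topk_C = X[b_idx, top_idx]   # indices broadcast to [B, k]; the trailing C dimension comes along, giving [B, k, C]
X_topk_D = torch.gather (X, 1, top_idx.unsqueeze (-1).expand (B, k, C))   # same result via gather along dim 1

print ('X_topk_C.equal (X_topk_B):', X_topk_C.equal (X_topk_B))
print ('X_topk_D.equal (X_topk_B):', X_topk_D.equal (X_topk_B))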

Best.

K. Frank