Assume I have a model that takes as input a sequence of vectors, i.e. input.shape == [B, N, C], with B the batch dimension, N the sequence length, and C the latent dimension.
Additionally, at some point in my model I get a tensor w of
shape [B, N] that weights each element of my sequence, i.e. w.sum(dim=1) is a tensor of ones of size [B,].
The following code shows my attempt to get the top-k sequence elements of my input X that have the highest values in w:
X = torch.rand(4,7,128)
w = torch.rand(4,7)
w_sorted = torch.argsort(w, dim = 1, descending = True)
w_trunc = w_sorted[:,:5]
X_topk = torch.empty(X.shape[0], 5, 128)
for i in range(X.shape[0]):
    X_topk[i, :, :] = X[i, w_trunc[i, :], :]
Is there a better way to do this without the loop, directly indexing multiple dimensions of X using w_trunc?
Yes, you can create appropriate indices and use PyTorch tensor
indexing to avoid the loop. (You can also gain a bit of efficiency
by using torch.topk() to perform the typically cheaper partial
sort rather than the full sort performed by torch.argsort().)
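As a quick sanity check of what torch.topk() returns, here is a small, self-contained example (the tensor values are made up for illustration). It returns both the k largest values along a dimension and their indices, and the indices are what you need in place of the truncated argsort:

```python
import torch

w = torch.tensor([[0.1, 0.9, 0.4],
                  [0.7, 0.2, 0.8]])

# topk only finds the k largest entries per row (a partial sort),
# returning them in descending order together with their indices
top = torch.topk(w, k=2, dim=1)
print(top.values)   # tensor([[0.9000, 0.4000], [0.8000, 0.7000]])
print(top.indices)  # tensor([[1, 2], [2, 0]])
```

top.indices here plays the same role as w_trunc in your loop version.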
Here is an illustrative script:
import torch
print (torch.__version__)
_ = torch.manual_seed (2021)
B = 4
N = 7
C = 128
k = 5
X = torch.rand (B, N, C)
w = torch.rand (B, N)
w_sorted = torch.argsort (w, dim = 1, descending = True)
w_trunc = w_sorted[:, :k]
# method with loop
X_topk = torch.empty(X.shape[0], k, C)
for i in range(X.shape[0]):
    X_topk[i, :, :] = X[i, w_trunc[i, :], :]
# method with no-loop indexing and partial instead of full sort
i0 = torch.arange (B).unsqueeze (-1).unsqueeze (-1).expand (B, k, C)
i1 = torch.topk (w, k, dim = 1).indices.unsqueeze (-1).expand (B, k, C)
i2 = torch.arange (C).expand (B, k, C)
X_topk_B = X[i0, i1, i2]
print ('X_topk_B.equal (X_topk):', X_topk_B.equal (X_topk))
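As a side note, the same selection can also be written with torch.gather(), which spares you from building the i0 and i2 index tensors explicitly; only the topk indices need to be expanded over the channel dimension. A sketch:

```python
import torch

B, N, C, k = 4, 7, 128, 5
X = torch.rand(B, N, C)
w = torch.rand(B, N)

idx = torch.topk(w, k, dim=1).indices        # shape [B, k]
# expand the indices over the channel dimension so gather
# picks out whole C-dimensional vectors along dim 1
idx_exp = idx.unsqueeze(-1).expand(B, k, C)  # shape [B, k, C]
X_topk_g = torch.gather(X, 1, idx_exp)       # shape [B, k, C]
```

Both versions produce the same result; which one is faster can depend on the backend, so it may be worth benchmarking on your actual shapes.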