Indexing tensors of varying dimensions with topk

I am currently using torch.topk to determine the indices of the top-k entries of a 2D tensor scores of size [Batch, N]. I can get the top-k values (k = 6000) from scores with torch.gather (or directly from torch.topk).

idx = torch.topk(scores, 6000, dim=1, sorted=True).indices  # topk returns (values, indices)
scores = torch.gather(scores, dim=1, index=idx) # Output is of size [B, 6000]
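Since torch.topk returns a (values, indices) named tuple, the indices have to be pulled out before they are passed to gather. A small self-contained check, using made-up sizes instead of 6000:

```python
import torch

B, N, k = 2, 10, 4  # small illustrative sizes (the post uses k = 6000)
scores = torch.randn(B, N)

# torch.topk returns a named tuple (values, indices), not a bare index tensor
top = torch.topk(scores, k, dim=1, sorted=True)
idx = top.indices                                  # shape [B, k], dtype int64
gathered = torch.gather(scores, dim=1, index=idx)  # shape [B, k]

print(gathered.shape)                     # torch.Size([2, 4])
print(torch.equal(gathered, top.values))  # True
```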

My issue comes when I try to use the same indices on a 3D tensor bbox of size [Batch, N, 4]. How could I use the same indices to get something like the line below, without resorting to for loops?

bbox = torch.gather(bbox, dim=1, index=idx) # Desired output of size [B, 6000, 4]
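One common loop-free approach is to give idx a trailing axis and broadcast it with expand() so that it has the same rank as bbox before calling gather. A minimal sketch; the sizes B, N and k are illustrative, not from the post:

```python
import torch

B, N, k = 2, 10, 4  # illustrative sizes
scores = torch.randn(B, N)
bbox = torch.randn(B, N, 4)

idx = torch.topk(scores, k, dim=1, sorted=True).indices  # [B, k]

# Add a trailing axis and broadcast it over the 4 box coordinates.
# expand() returns a view, so the index data is not copied.
idx_box = idx.unsqueeze(-1).expand(-1, -1, bbox.size(-1))  # [B, k, 4]
top_bbox = torch.gather(bbox, dim=1, index=idx_box)         # [B, k, 4]

# Sanity check against plain advanced indexing
batch = torch.arange(B).unsqueeze(1)  # [B, 1], broadcasts against idx
print(torch.equal(top_bbox, bbox[batch, idx]))  # True
```

The advanced-indexing form bbox[batch, idx] used in the check is itself a valid alternative; gather with an expanded index just keeps everything in one call.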

Your code should also work for this shape:

import torch

scores = torch.randn(10, 6000)
idx = torch.topk(scores, 6000, dim=1, sorted=True)
out = torch.gather(scores, dim=1, index=idx.indices)
print((torch.sort(scores, dim=1, descending=True).values == out).all())
> tensor(True)

scores = torch.randn(10, 6000, 4)
idx = torch.topk(scores, 6000, dim=1, sorted=True)
out = torch.gather(scores, dim=1, index=idx.indices)
print((torch.sort(scores, dim=1, descending=True).values == out).all())
> tensor(True)

Do you see an error?

I understand that, but wouldn't that require reshaping or modifying scores to obtain indices for both the 2D and the 3D tensors, rather than having one solution that works for both?
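A sketch of a single helper that handles both shapes, assuming the index always comes from a [B, k] topk over dim 1 and the target tensor is [B, N, *rest] with any number of trailing dims. gather_rows is a hypothetical name for illustration, not a PyTorch function:

```python
import torch


def gather_rows(t: torch.Tensor, idx: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper (not part of PyTorch).

    Selects entries of `t` along `dim` using `idx`, whose shape matches the
    leading idx.dim() dims of `t` (except along `dim`). Works for a [B, N]
    score tensor as well as a [B, N, *rest] box tensor.
    """
    extra = t.dim() - idx.dim()  # number of trailing dims to broadcast over
    idx_full = idx.reshape(*idx.shape, *([1] * extra))
    idx_full = idx_full.expand(*idx.shape, *t.shape[idx.dim():])  # view, no copy
    return torch.gather(t, dim, idx_full)


scores = torch.randn(3, 8)
bbox = torch.randn(3, 8, 4)
idx = scores.topk(5, dim=1).indices  # [3, 5]

top_scores = gather_rows(scores, idx)  # [3, 5]
top_bbox = gather_rows(bbox, idx)      # [3, 5, 4]
print(top_scores.shape, top_bbox.shape)  # torch.Size([3, 5]) torch.Size([3, 5, 4])
```

For a 2D input, extra is 0 and the reshape/expand calls are no-ops, so the same function covers both cases.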

One not-entirely-elegant solution I came up with is:

idx_stacked = torch.stack([idx] * bbox.shape[-1], dim=idx.dim())  # [B, 6000, 4]
bbox = torch.gather(bbox, dim=1, index=idx_stacked)

Unsure if there is a cleaner way to do it.
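Another option, assuming a PyTorch version that has torch.take_along_dim (1.9 or later, as far as I know), is to let that function broadcast the index, so only a single trailing singleton axis is needed. A sketch with illustrative sizes, checked against the stacking approach above:

```python
import torch

B, N, k = 2, 10, 4  # small illustrative sizes
scores = torch.randn(B, N)
bbox = torch.randn(B, N, 4)
idx = scores.topk(k, dim=1).indices  # [B, k]

# Stacking approach from the post: materializes one copy of idx per coordinate
idx_stacked = torch.stack([idx] * bbox.shape[-1], dim=idx.dim())  # [B, k, 4]
via_stack = torch.gather(bbox, dim=1, index=idx_stacked)

# take_along_dim broadcasts the index against the input (except along dim),
# so a size-1 trailing axis is enough
via_take = torch.take_along_dim(bbox, idx.unsqueeze(-1), dim=1)   # [B, k, 4]

print(torch.equal(via_stack, via_take))  # True
```

The same call also works on the 2D tensor directly, e.g. torch.take_along_dim(scores, idx, dim=1), since there idx already has the right rank.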

Hello, can you please help with this example?

import torch

x = torch.rand(10,5)
y = torch.rand(10,4,5)
top_x,top_inds = x.topk(3,dim=-1)
top_y= y.gather(dim=-1,index=top_inds)

It does not work; it raises:

RuntimeError: Index tensor must have the same number of dimensions as input tensor
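The error comes from gather requiring an index with the same number of dimensions as the input. One way around it, sketched below with the shapes from this example, is to insert the channel axis into top_inds and expand it over the 4 channels of y:

```python
import torch

x = torch.rand(10, 5)
y = torch.rand(10, 4, 5)
top_x, top_inds = x.topk(3, dim=-1)  # top_inds: [10, 3]

# gather needs an index with as many dims as y: insert the channel axis
# (size 1) and broadcast it to y's 4 channels -> [10, 4, 3]
index = top_inds.unsqueeze(1).expand(-1, y.size(1), -1)
top_y = y.gather(dim=-1, index=index)  # [10, 4, 3]

print(top_y.shape)  # torch.Size([10, 4, 3])
```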

Currently, I have to work around this with a for loop:

s_inds = torch.arange(len(x)).long()
top_y = [y[s_inds,:,top_inds[:,i]] for i in range(3)]
top_y = torch.stack(top_y,dim=2)

My slightly better solution is:

n_samples, topk = len(x), 3
s_inds = torch.arange(len(x)).long()
top_y = y[s_inds.repeat_interleave(topk),:,top_inds.view(-1)] #(n_samples*topk,C)
top_y = top_y.view(n_samples,topk,-1).permute(0,2,1) #(n_samples,C,topk)
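The repeat_interleave step can also be avoided by letting the two advanced indices broadcast against each other. A sketch, assuming the same x and y shapes as above:

```python
import torch

x = torch.rand(10, 5)
y = torch.rand(10, 4, 5)
n_samples, topk = x.size(0), 3
_, top_inds = x.topk(topk, dim=-1)  # [n_samples, topk]

# A column of sample indices broadcasts against top_inds ([n, 1] vs [n, topk]),
# so no repeat_interleave is needed. Because the two advanced indices are
# separated by a slice, the broadcast dims are placed first in the result.
s_inds = torch.arange(n_samples).unsqueeze(1)  # [n_samples, 1]
top_y = y[s_inds, :, top_inds]                 # [n_samples, topk, C]
top_y = top_y.permute(0, 2, 1)                 # [n_samples, C, topk]
```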