# Indexing tensors of varying dimensions with topk

I am currently using `torch.topk` to determine the indices of the top-k values of a 2D tensor `scores` of size [Batch, N]. I can get the top-k values (k = 6000) from `scores` with `torch.gather` (or simply directly from `torch.topk`).

```python
import torch

idx = torch.topk(scores, 6000, dim=1, sorted=True).indices  # [B, 6000]
scores = torch.gather(scores, dim=1, index=idx)  # Output is of size [B, 6000]
```
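`torch.topk` returns a named tuple of `(values, indices)`, so the indices have to be taken from `.indices` before they are passed to `torch.gather`. A small check, using arbitrary sizes and random data:

```python
import torch

scores = torch.randn(2, 10)  # [B, N]
k = 3

# topk returns a named tuple (values, indices)
top = torch.topk(scores, k, dim=1, sorted=True)
print(top.indices.shape)  # torch.Size([2, 3])

# Gathering with the indices reproduces the values topk already returned
gathered = torch.gather(scores, dim=1, index=top.indices)
print(torch.equal(gathered, top.values))  # True
```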

My issue comes when I try to use the same indices on a 3D tensor `bbox` of size [Batch, N, 4]. How could I use the same indices to get something like the following without resorting to for loops?

```python
bbox = torch.gather(bbox, dim=1, index=idx)  # Desired output of size [B, 6000, 4]
```
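Another way to get a `[B, k, 4]` result is advanced indexing: pair a column of batch indices with `idx` so that each row selects its own `k` boxes, while the trailing dimension of size 4 is carried along unchanged. A sketch with small made-up sizes (`B`, `N`, `k` here are illustrative, not from the post):

```python
import torch

B, N, k = 2, 8, 3
scores = torch.randn(B, N)   # [B, N]
bbox = torch.randn(B, N, 4)  # [B, N, 4]

idx = torch.topk(scores, k, dim=1, sorted=True).indices  # [B, k]

# batch_idx has shape [B, 1] and broadcasts against idx [B, k]
batch_idx = torch.arange(B).unsqueeze(1)
top_bbox = bbox[batch_idx, idx]  # [B, k, 4]
print(top_bbox.shape)  # torch.Size([2, 3, 4])

# Row b, slot j holds bbox[b, idx[b, j]]
assert torch.equal(top_bbox[1, 0], bbox[1, idx[1, 0]])
```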

Your code should also work for this shape:

```python
import torch

scores = torch.randn(10, 6000)
idx = torch.topk(scores, 6000, dim=1, sorted=True)
out = torch.gather(scores, dim=1, index=idx.indices)
print((torch.sort(scores, dim=1, descending=True).values == out).all())
# tensor(True)

scores = torch.randn(10, 6000, 4)
idx = torch.topk(scores, 6000, dim=1, sorted=True)
out = torch.gather(scores, dim=1, index=idx.indices)
print((torch.sort(scores, dim=1, descending=True).values == out).all())
# tensor(True)
```

Do you see an error?

I understand that, but wouldn't that require reshaping or otherwise modifying `scores` to obtain indices for both the 2D and 3D tensors, rather than a single solution that works for both?
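One way to avoid touching `scores` is to compute the indices once from the 2D `scores` and then expand them to whatever trailing shape the target tensor has. The helper below (`gather_rows` is a hypothetical name, not a PyTorch function) is a sketch assuming the selection dimension is `dim=1` and any extra dimensions come after it:

```python
import torch

def gather_rows(t: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: select idx [B, k] along dim 1 of t [B, N, *rest]."""
    # Append one singleton dim per trailing dimension of t
    extra = t.dim() - idx.dim()
    expanded = idx.reshape(*idx.shape, *([1] * extra))
    # Broadcast (as a view, without copying) to the trailing sizes of t
    expanded = expanded.expand(*idx.shape, *t.shape[idx.dim():])
    return torch.gather(t, dim=1, index=expanded)

scores = torch.randn(2, 8)
bbox = torch.randn(2, 8, 4)
idx = torch.topk(scores, 3, dim=1).indices  # computed once, from 2D scores

print(gather_rows(scores, idx).shape)  # torch.Size([2, 3])
print(gather_rows(bbox, idx).shape)    # torch.Size([2, 3, 4])
```

The same call works for the 2D and the 3D tensor because the index is only broadcast, never recomputed.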

One not-entirely-elegant solution I came up with was:

```python
# Repeat idx [B, 6000] along a new last dim -> [B, 6000, 4]
idx_stacked = torch.stack([idx] * bbox.shape[-1], dim=idx.dim())
bbox = torch.gather(bbox, dim=1, index=idx_stacked)  # [B, 6000, 4]
```

I'm unsure whether there is a cleaner way to do it.
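A common alternative to stacking copies of `idx` is `unsqueeze` followed by `expand`, which builds the same `[B, k, 4]` index as a broadcast view instead of allocating copies. A sketch with small, arbitrary sizes:

```python
import torch

B, N, k = 2, 8, 3
scores = torch.randn(B, N)
bbox = torch.randn(B, N, 4)
idx = torch.topk(scores, k, dim=1, sorted=True).indices  # [B, k]

# [B, k] -> [B, k, 1] -> [B, k, 4] (a view, no data copied)
idx_expanded = idx.unsqueeze(-1).expand(-1, -1, bbox.size(-1))
top_bbox = torch.gather(bbox, dim=1, index=idx_expanded)  # [B, k, 4]

# Same result as the stacking approach
idx_stacked = torch.stack([idx] * bbox.shape[-1], dim=idx.dim())
assert torch.equal(top_bbox, torch.gather(bbox, dim=1, index=idx_stacked))
```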

I tried the following:

```python
import torch

x = torch.rand(10, 5)
y = torch.rand(10, 4, 5)
top_x, top_inds = x.topk(3, dim=-1)  # top_inds: [10, 3]
top_y = y.gather(dim=-1, index=top_inds)
```

This does not work; it raises:

`RuntimeError: Index tensor must have the same number of dimensions as input tensor`
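The error arises because `torch.gather` needs an index with the same number of dimensions as its input. For this case (selecting along the last dimension of `y` with indices computed from `x`), one sketch is to insert the missing middle dimension into `top_inds` and expand it over `y`'s second dimension:

```python
import torch

x = torch.rand(10, 5)
y = torch.rand(10, 4, 5)
top_x, top_inds = x.topk(3, dim=-1)  # top_inds: [10, 3]

# [10, 3] -> [10, 1, 3] -> [10, 4, 3]
index = top_inds.unsqueeze(1).expand(-1, y.size(1), -1)
top_y = y.gather(dim=-1, index=index)  # [10, 4, 3]
print(top_y.shape)  # torch.Size([10, 4, 3])

# top_y[b, c, j] == y[b, c, top_inds[b, j]]
assert torch.equal(top_y[0, 2, 1], y[0, 2, top_inds[0, 1]])
```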

Currently, I work around this with a for loop:

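```python
s_inds = torch.arange(len(x)).long()  # batch indices 0..9
# Each iteration picks the i-th top index per row -> [10, 4]
top_y = [y[s_inds, :, top_inds[:, i]] for i in range(3)]
top_y = torch.stack(top_y, dim=2)  # [10, 4, 3]
```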

My slightly better solution is:

```python
n_samples, topk = len(x), 3
s_inds = torch.arange(n_samples).long()
# Repeat each batch index topk times to pair it with the flattened top indices
top_y = y[s_inds.repeat_interleave(topk), :, top_inds.view(-1)]  # (n_samples*topk, C)
top_y = top_y.view(n_samples, topk, -1).permute(0, 2, 1)  # (n_samples, C, topk)
```