Hi all,
I am trying to get the top-k elements of a multi-dimensional tensor over its last dimension, but k is a tensor (one k per row) rather than an int.
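For example:
```python
import torch

a = torch.rand(2, 3, 4)          # shape (2, 3, 4)
b = torch.randint(0, 4, (2, 3))  # shape (2, 3): one k per row of a

# What I want to do, but it's slow:
results = []
for i in range(2):
    for j in range(3):
        values, indices = torch.topk(a[i, j, :], int(b[i, j]))
        results.append(indices)
```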
Is there an efficient way to do this without the for loop?
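One idea I've been sketching (the helper name `topk_variable_k` is my own, not a PyTorch API): call `torch.topk` once on the whole tensor with the largest k in `b`, then build a mask that keeps only the first `b[i, j]` slots of each row. Because `torch.topk` returns results sorted in descending order by default, the first `b[i, j]` entries of each row should be exactly that row's top-k (barring ties).

```python
import torch


def topk_variable_k(a: torch.Tensor, k: torch.Tensor):
    """Top-k over the last dim of `a`, with a per-row k given by `k`.

    `k` must have shape a.shape[:-1]. Returns padded `values` and `indices`
    of shape (..., k_max) plus a boolean `mask` of the same shape that is
    True for the slots that belong to each row's own top-k.
    """
    k_max = int(k.max())
    # A single batched topk with the largest k; sorted=True (the default)
    # puts each row's largest elements first.
    values, indices = torch.topk(a, k_max, dim=-1)
    # mask[..., t] is True when slot t < k for that row (broadcast compare).
    positions = torch.arange(k_max, device=a.device)
    mask = positions < k.unsqueeze(-1)
    return values, indices, mask


a = torch.rand(2, 3, 4)
b = torch.randint(0, 4, (2, 3))
values, indices, mask = topk_variable_k(a, b)
```

Since each row keeps a different number of elements, the result is ragged, so a mask seems like the natural representation. `values.masked_fill(~mask, float("-inf"))` pads the unused slots, and `values[mask]` gives all selected values concatenated in row-major order. The cost is one top-k with `k_max` per row, which should be much cheaper than a Python loop over every row.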
Thank you!