I am trying to get the topk over the last dimension of a multi-dimensional tensor. The k value is a tensor rather than an int, so every position in the leading dimensions has its own k.
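```python
import torch

a = torch.rand(2, 3, 4)          # size (2, 3, 4)
b = torch.randint(0, 4, (2, 3))  # size (2, 3): one k per row of a

# What I want to do, but slow
for i in range(2):
    for j in range(3):
        # topk returns (values, indices); k must be a Python int
        values, indices = torch.topk(a[i, j, :], int(b[i, j]))
```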
Is there an efficient way to do this without the for loop?
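One common way to vectorize this, sketched here under the assumption that a padded result plus a mask is acceptable (rows with different k cannot form a regular tensor), is to call `torch.topk` once with the largest k and then mask out the extra entries per row. Because `torch.topk` returns its results sorted in descending order by default, the first `b[i, j]` entries of row `(i, j)` are that row's top `b[i, j]` values. The function name `batched_topk_variable_k` and the `-1` padding value are hypothetical choices for illustration. With tied values, the returned indices may differ from a per-row call, but the values are the same.

```python
import torch


def batched_topk_variable_k(a: torch.Tensor, k: torch.Tensor):
    """Hypothetical helper: top-k over the last dim of `a`, with a separate k per row.

    a: (..., n) tensor
    k: (...) integer tensor with the same leading shape as `a`, values in [0, n]
    Returns (values, indices, mask), each of shape (..., max_k).
    Only entries where mask is True are valid; the rest are padding.
    """
    max_k = int(k.max())
    # A single topk with the largest k; output is sorted descending,
    # so each row's first k[...] entries are exactly its own top-k.
    values, indices = torch.topk(a, max_k, dim=-1)
    # mask[..., t] is True iff t < k[...] (broadcasts (max_k,) against (..., 1))
    positions = torch.arange(max_k, device=a.device)
    mask = positions < k.unsqueeze(-1)
    return values, indices, mask


# Usage with the shapes from the question
a = torch.rand(2, 3, 4)
b = torch.randint(0, 4, (2, 3))
values, indices, mask = batched_topk_variable_k(a, b)
padded_indices = indices.masked_fill(~mask, -1)  # -1 marks unused slots
```

If a list of ragged per-row results is needed afterwards, `indices[i, j][mask[i, j]]` recovers row `(i, j)` without calling `topk` again.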