Apply the same function (like torch.topk) to a list of tensors

Hi all,
I am trying to get the top-k elements of a multi-dimensional tensor along its last dimension. The k value is a tensor (one k per row) rather than an int.
For example:

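import torch

a = torch.rand(2, 3, 4)            # size (2, 3, 4)
b = torch.randint(0, 4, (2, 3))    # size (2, 3): one k per row of a

# What I want to do, but slow
results = []
for i in range(2):
    for j in range(3):
        # topk expects an int for k, so convert the 0-dim tensor b[i, j]
        values, indices = torch.topk(a[i, j, :], int(b[i, j]))
        results.append(indices)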

Is there an efficient way to do this instead of a for loop?
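One possible vectorized approach, sketched under the assumption that a padded result plus a validity mask is acceptable: call `torch.topk` once with the largest k in `b`, then mask out the positions beyond each row's own k. Because `topk` returns results sorted in descending order by default (`sorted=True`), the first `b[i, j]` entries of row `(i, j)` are exactly that row's top-k. The helper name `ragged_topk` is hypothetical.

```python
import torch


def ragged_topk(a: torch.Tensor, k: torch.Tensor):
    """Top-k over the last dim of `a`, with a per-row k given by integer tensor `k`.

    `k` must have shape a.shape[:-1]. Returns padded values and indices of shape
    (*a.shape[:-1], k.max()) plus a boolean mask that is True where an entry
    lies within that row's k.
    """
    k_max = int(k.max())
    # A single topk call with the largest k; entries are sorted in descending order
    values, indices = torch.topk(a, k_max, dim=-1)
    # Position p in the last dim is valid for a row iff p < that row's k
    positions = torch.arange(k_max, device=a.device)
    mask = positions < k.unsqueeze(-1)
    return values, indices, mask


a = torch.rand(2, 3, 4)
b = torch.randint(0, 4, (2, 3))
values, indices, mask = ragged_topk(a, b)
```

If you need the variable-length results as separate tensors, `indices[mask]` gives all valid indices flattened in row-major order, and `indices[mask].split(b.flatten().tolist())` splits them back into one tensor per row. Only the final split produces a Python-level list; the top-k itself is a single kernel call.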

Thank you!