How do I get the top-n entries of a matrix along its second axis?

For example, I use numpy:

I would like to do the same in PyTorch, but my attempt failed and I do not know how to solve it.

You can use torch.topk.
It returns a tuple with the maximum entries (like sort) and their indices (like argsort).
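As a minimal sketch of the second-axis case from the question (the small example matrix and the choice k=2 are made up for illustration), passing dim=1 selects the largest entries within each row:

```python
import torch

# Hypothetical 2x4 example matrix, chosen only for illustration.
x = torch.tensor([[1.0, 5.0, 3.0, 4.0],
                  [9.0, 2.0, 8.0, 7.0]])

# Top-2 entries along the second axis (dim=1), i.e. per row.
# By default the results are sorted in descending order.
values, indices = torch.topk(x, k=2, dim=1)

print(values)   # tensor([[5., 4.], [9., 8.]])
print(indices)  # tensor([[1, 3], [0, 2]])
```

The indices can be fed to torch.gather(x, 1, indices) to recover the same values from x.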

Best regards



Hey @tom, torch.topk is significantly slower than torch.sort

I’m running it on a 1-D tensor with 300k elements and taking the top k=200.

Sample code:

import torch
x = torch.rand(300000, device='cuda')

%timeit torch.topk(x, 200) # 1000 loops, best of 5: 2.96 ms per loop
%timeit torch.sort(x, descending=True)[0][:200] # 1000 loops, best of 5: 812 µs per loop

Can you please help me here?
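Since CUDA kernels are launched asynchronously, a timing that synchronizes the device can be a useful cross-check. The following is only a sketch using torch.utils.benchmark.Timer, which handles synchronization; the number of runs (20) and the CPU fallback are assumptions added for illustration, and the absolute numbers will differ by hardware and PyTorch version:

```python
import torch
from torch.utils import benchmark

# Fall back to CPU when no GPU is present (assumption, for portability).
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(300_000, device=device)

# Timer synchronizes CUDA around the measured statement.
t_topk = benchmark.Timer(
    stmt="torch.topk(x, 200)",
    globals={"torch": torch, "x": x},
)
t_sort = benchmark.Timer(
    stmt="torch.sort(x, descending=True)[0][:200]",
    globals={"torch": torch, "x": x},
)

m_topk = t_topk.timeit(20)  # 20 runs, an arbitrary choice
m_sort = t_sort.timeit(20)

# Measurement.mean is reported in seconds.
print(f"topk:       {m_topk.mean * 1e6:.1f} us")
print(f"sort+slice: {m_sort.mean * 1e6:.1f} us")
```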


What happens is that torch.sort (for large tensors, if memory serves me well) dispatches to the Thrust library’s sort, which is likely better optimized than PyTorch’s own kernels.
Note that using sort has memory implications: the 200-element slice you take from the result is a view of the full sorted 300k-element tensor, so the whole sorted tensor stays in memory for as long as the slice is alive. Depending on what you do (in particular when autograd is involved), this can be a disadvantage.
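To make the view behaviour concrete, here is a small sketch (variable names are illustrative, and it runs on CPU so it works without a GPU). It checks that the slice shares memory with the full sorted tensor, and that .clone() produces an independent 200-element copy, which is one possible way to let the large tensor be freed:

```python
import torch

x = torch.rand(300_000)

sorted_vals = torch.sort(x, descending=True)[0]
top = sorted_vals[:200]          # a view: no data is copied

# The slice starts at the same memory address as the full sorted tensor,
# so it keeps the whole 300k-element buffer alive.
print(top.data_ptr() == sorted_vals.data_ptr())      # True

# clone() copies just the 200 elements into their own buffer.
compact = top.clone()
print(compact.data_ptr() == sorted_vals.data_ptr())  # False
print(compact.numel())                               # 200
```

After the clone, dropping the references to sorted_vals and top allows the 300k-element buffer to be released, at the cost of one small copy.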

Best regards