How do I get the top-n values of a matrix along its second axis?

For example, in NumPy I use:

I would like to do the same in PyTorch, but my attempt failed and I do not know how to fix it.

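You can use torch.topk; pass dim=1 to select along the second axis (the default, dim=-1, is the last axis, which for a matrix is also the second one).
It returns a tuple with the k largest entries (like sort) and their indices (like argsort).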
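A minimal sketch of torch.topk along the second axis, using a small made-up matrix (the data is hypothetical). It also checks that the result matches a full descending sort truncated to k columns, and that the returned indices gather the values back from the original matrix:

```python
import torch

# Small example matrix (hypothetical data): 2 rows, 4 columns.
m = torch.tensor([[1.0, 5.0, 3.0, 4.0],
                  [9.0, 2.0, 8.0, 7.0]])

# Top-2 entries of each row, i.e. along the second axis (dim=1).
values, indices = torch.topk(m, k=2, dim=1)
print(values.tolist())   # [[5.0, 4.0], [9.0, 8.0]]
print(indices.tolist())  # [[1, 3], [0, 2]]

# Same result via a full descending sort, keeping only the first k columns.
sorted_vals, sorted_idx = torch.sort(m, dim=1, descending=True)
assert torch.equal(sorted_vals[:, :2], values)
assert torch.equal(sorted_idx[:, :2], indices)

# The indices pick the top values back out of the original matrix.
assert torch.equal(torch.gather(m, 1, indices), values)
```

By default topk returns the values in descending order (sorted=True, largest=True); pass largest=False to get the k smallest entries instead.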

Best regards

Thomas


Hey @tom, in my benchmark torch.topk is significantly slower than torch.sort.

I’m running it on a CUDA tensor with 300k elements and taking the top k=200.

Sample code (run in IPython/Jupyter, since %timeit is an IPython magic):

import torch
x = torch.rand(300000, device='cuda')

%timeit torch.topk(x, 200) # 1000 loops, best of 5: 2.96 ms per loop
%timeit torch.sort(x, descending=True)[0][:200] # 1000 loops, best of 5: 812 µs per loop
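One caveat on the numbers above: CUDA kernels launch asynchronously, so plain %timeit without synchronization can be misleading for GPU code. A hedged sketch of a more careful comparison using torch.utils.benchmark.Timer, which synchronizes CUDA around the timed statement; it falls back to the CPU if no GPU is available, and the run count of 50 is an arbitrary choice:

```python
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(300000, device=device)
k = 200

# Sanity check: both approaches give the same top-k values, in descending order.
topk_vals = torch.topk(x, k).values
sort_vals = torch.sort(x, descending=True).values[:k]
assert torch.equal(topk_vals, sort_vals)

# Timer handles CUDA synchronization, unlike a bare %timeit loop.
env = {"torch": torch, "x": x, "k": k}
timers = {
    "topk": benchmark.Timer(stmt="torch.topk(x, k)", globals=env),
    "sort": benchmark.Timer(stmt="torch.sort(x, descending=True).values[:k]", globals=env),
}
results = {name: t.timeit(50) for name, t in timers.items()}
for name, measurement in results.items():
    print(f"{name} on {device}: median {measurement.median * 1e6:.1f} us")
```

The relative speed of the two calls depends on the device, tensor size, k, and PyTorch version, so it is worth rerunning this on your own setup.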

Can you please help me here?

Regards
Nikhil

What happens is that torch.sort (for large tensors, if memory serves me well) dispatches to the Thrust library’s sort, which is likely better optimized than PyTorch’s own topk kernels.
Note that using sort has memory implications: the slice you get from it is a view of the full sorted 300k-element tensor, so that whole tensor stays in memory as long as the slice is alive. Depending on what you do (in particular when autograd is involved), this can be a disadvantage.
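To illustrate the view issue, here is a small sketch on a CPU tensor; it assumes PyTorch 2.x, where Tensor.untyped_storage() is available. Calling .clone() on the slice copies only the k kept values into a new buffer, so the large sorted buffer can be freed once nothing else references it:

```python
import torch

x = torch.rand(300000)  # CPU float32 tensor, same size as in the benchmark
k = 200

# Slicing the sorted result gives a view that shares the full sorted buffer.
top_view = torch.sort(x, descending=True).values[:k]
print(top_view.untyped_storage().nbytes())  # 1200000 (300000 float32 values)

# clone() copies just the k values into a new, small buffer,
# so top_copy no longer keeps the 300k-element buffer alive.
top_copy = top_view.clone()
print(top_copy.untyped_storage().nbytes())  # 800 (200 float32 values)
```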

Best regards

Thomas