How to implement a more general argsort

In torch.sort or torch.argsort, I can set descending to True or False to get the sort order I want.
How can I instead sort by a custom comparison order, given separately for each row? For example, given the following tensor to be sorted and these comparison orders:

a = torch.tensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare = [torch.tensor([2, 1]), torch.tensor([2, 3, 1])]

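I want a magic_argsort function that orders each row of a by the position of its values in the corresponding compare tensor, keeping equal values in their original order, and so gives me the following result: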

result = magic_argsort(a, compare)
# result is torch.tensor([[0, 3, 4, 5, 1, 2, 6], [3, 4, 5, 0, 1, 2, 6]])
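
To pin down the intended semantics, here is a plain-Python sketch that reproduces the expected result. reference_argsort is a hypothetical name, and the sketch assumes every value in a row occurs in that row's comparison order. Python's sorted() is stable, so equal values keep their original order.

```python
# Plain-Python reference for the desired ordering (no PyTorch needed).
# Assumption: every value in a row appears in that row's compare list.
def reference_argsort(rows, compare):
    result = []
    for row, order in zip(rows, compare):
        rank = {v: i for i, v in enumerate(order)}  # value -> position in compare
        # sorted() is stable, so elements with equal rank keep their original order
        result.append(sorted(range(len(row)), key=lambda j: rank[row[j]]))
    return result

a = [[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]]
compare = [[2, 1], [2, 3, 1]]
print(reference_argsort(a, compare))
# [[0, 3, 4, 5, 1, 2, 6], [3, 4, 5, 0, 1, 2, 6]]
```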

Hello Usami!

I do not believe that pytorch’s sort() supports a custom sort order
or comparator argument.

But you can implement it relatively easily. The idea is to look up
each value to be sorted in the corresponding comparison-order vector,
which gives its position in that vector, and then argsort those positions.
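
The two steps can be sketched on the second row of the example. This assumes a PyTorch release whose torch.sort accepts stable=True (0.3.0 does not), and it uses a plain dictionary as a stand-in for a tensor-based lookup.

```python
import torch

# One row from the example and its comparison-order vector.
d = torch.tensor([3, 1, 1, 2, 2, 2, 1])
c = torch.tensor([2, 3, 1])

# Step 1: look up each value's position in c (value -> index in c).
rank = {v: i for i, v in enumerate(c.tolist())}
keys = torch.tensor([rank[v] for v in d.tolist()])
print(keys)   # tensor([1, 2, 2, 0, 0, 0, 2])

# Step 2: argsort those positions; stable=True keeps ties in their original order.
order = torch.sort(keys, stable=True).indices
print(order)  # tensor([3, 4, 5, 0, 1, 2, 6])
```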

Note that in your example use case, your compare is a python list
of 1D tensors, not a 2D tensor. Furthermore, those tensors differ in
length, so they cannot be stacked into a single 2D tensor. Therefore
we will have to use a loop, and not just pure pytorch tensor operations.
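
For contrast, if the comparison orders did have equal length, the loop could be avoided by broadcasting over a 3D mask. This is a minimal sketch, assuming a recent PyTorch whose torch.sort accepts stable=True. The function name magic_argsort_2d is hypothetical, and the first order is hypothetically padded to [2, 1, 3] so the rows can be stacked; 3 does not occur in the first row, so the padding does not change its result.

```python
import torch

# Hypothetical equal-length case: the orders can be stacked into one 2D tensor.
a = torch.tensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare2d = torch.tensor([[2, 1, 3], [2, 3, 1]])

def magic_argsort_2d(data, order):
    # (rows, n, 1) == (rows, 1, k) broadcasts to a (rows, n, k) boolean mask
    mask = data.unsqueeze(2) == order.unsqueeze(1)
    # position of each value in its row's order; assumes exactly one match
    keys = mask.long().argmax(dim=2)
    # stable sort along each row keeps ties in their original order
    return torch.sort(keys, dim=1, stable=True).indices

print(magic_argsort_2d(a, compare2d))
# tensor([[0, 3, 4, 5, 1, 2, 6],
#         [3, 4, 5, 0, 1, 2, 6]])
```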

Here is a pytorch 0.3.0 version of your magic_argsort() function
applied to your example:

import torch
torch.__version__

a = torch.LongTensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare = [torch.LongTensor([2, 1]), torch.LongTensor([2, 3, 1])]

# assumes that dataTensor and compareList are equal-length iterables of 1D tensors
# assumes that values in rows of dataTensor do occur in corresponding rows of compareList
def magic_argsort (dataTensor, compareList):
    r = []
    for  d, c  in zip (dataTensor, compareList):
        mask = d.unsqueeze (1) == c   # look up data values; unsqueeze so c can broadcast
        inds = mask.nonzero()[:,1]    # convert to sort-order indices
        r.append (inds.sort()[1])     # 0.3.0 doesn't have separate argsort function
    return torch.stack (r)            # return as torch tensor

magic_argsort (a, compare)

And here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> a = torch.LongTensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
>>> compare = [torch.LongTensor([2, 1]), torch.LongTensor([2, 3, 1])]
>>>
>>> # assumes that dataTensor and compareList are equal-length iterables of 1D tensors
... # assumes that values in rows of dataTensor do occur in corresponding rows of compareList
... def magic_argsort (dataTensor, compareList):
...     r = []
...     for  d, c  in zip (dataTensor, compareList):
...         mask = d.unsqueeze (1) == c   # look up data values; unsqueeze so c can broadcast
...         inds = mask.nonzero()[:,1]    # convert to sort-order indices
...         r.append (inds.sort()[1])     # 0.3.0 doesn't have separate argsort function
...     return torch.stack (r)            # return as torch tensor
...
>>> magic_argsort (a, compare)

    0     3     4     5     1     2     6
    3     4     5     0     1     2     6
[torch.LongTensor of size 2x7]
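
One caveat: the expected result relies on equal values (which map to equal positions) keeping their original order, but torch.sort only guarantees that when called with stable=True. The 0.3.0 output above happens to come out in that order. Here is a sketch of the same loop for a recent PyTorch release that accepts stable=True (untested against 0.3.0, which lacks that argument):

```python
import torch

# Ragged comparison orders, exactly as in the question.
a = torch.tensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare = [torch.tensor([2, 1]), torch.tensor([2, 3, 1])]

def magic_argsort(data, compare_list):
    rows = []
    for d, c in zip(data, compare_list):
        # (n, 1) == (k,) broadcasts to an (n, k) mask; column index = position in c
        keys = (d.unsqueeze(1) == c).nonzero()[:, 1]
        # stable=True guarantees equal keys keep their original order
        rows.append(torch.sort(keys, stable=True).indices)
    return torch.stack(rows)

print(magic_argsort(a, compare))
# tensor([[0, 3, 4, 5, 1, 2, 6],
#         [3, 4, 5, 0, 1, 2, 6]])
```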

Best.

K. Frank


Thanks, K. Frank. I had come up with a similar idea, but without the unsqueeze broadcast trick.
I now use this trick in my code, and it makes the corresponding code up to 4x faster.
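
To make the broadcast trick concrete, here is one plausible per-value lookup without it, next to the broadcast version, on the second row of the example. The per-value variant is an assumption; the code behind the comment above is not shown. Both produce the same positions, but the broadcast version builds the whole mask in a single comparison instead of one comparison per value.

```python
import torch

d = torch.tensor([3, 1, 1, 2, 2, 2, 1])
c = torch.tensor([2, 3, 1])

# Without the trick (hypothetical): look up each value separately, one comparison per value.
keys_loop = torch.stack([(c == v).nonzero()[0, 0] for v in d])

# With the trick: d becomes a (7, 1) column that broadcasts against c's (3,)
# shape, so a single comparison builds the whole (7, 3) mask.
mask = d.unsqueeze(1) == c
print(mask.shape)  # torch.Size([7, 3])
keys_broadcast = mask.nonzero()[:, 1]

print(torch.equal(keys_loop, keys_broadcast))  # True
```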