In torch.sort or torch.argsort, I can set descending to True or False to get the sort order I want.
How can I sort by a dynamic comparison order instead? For example, given the following tensor to be sorted:
I do not believe that pytorch’s sort() supports a custom sort order
or comparator argument.
But you can implement it relatively easily. The idea will be to look up
your values to be sorted in your comparison-order vectors, and then
sort those indices.
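To make the lookup-then-sort idea concrete, here is a minimal single-row sketch. It assumes a recent PyTorch release (torch.tensor, and torch.sort with the stable keyword), which is newer than the 0.3.0 code in this reply. The variable names ranks and order are only illustrative.

```python
import torch

# One row of data and its custom order (values taken from the example in this thread).
d = torch.tensor([3, 1, 1, 2, 2, 2, 1])
c = torch.tensor([2, 3, 1])  # 2 sorts first, then 3, then 1

# Compare every data value against every entry of c: result has shape (7, 3).
mask = d.unsqueeze(1) == c

# The column index of the single True in each row is that value's position in c.
ranks = mask.nonzero()[:, 1]

# Argsorting the positions gives the argsort of d under the custom order.
# stable=True keeps equal values in their original relative order.
order = torch.sort(ranks, stable=True).indices

print(ranks)     # positions of the data values within c
print(order)     # indices that sort d by the custom order
print(d[order])  # d rearranged: all 2s, then 3, then the 1s
```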
Note that in your example use case, your compare is a python list
of 1D tensors, not a 2D tensor. Furthermore, the rows differ in length.
Therefore we will have to use a loop, and not just pure pytorch tensor
operations.
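As an aside, if the loop ever becomes a bottleneck, one possible workaround (not part of the original answer) is to pad the comparison rows to a common length and do the lookup in a single batched operation. The sketch below assumes a recent PyTorch release and assumes that the padding value -1 never occurs in the data. The function name padded_argsort is hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def padded_argsort(data, compare_list, pad_value=-1):
    # ASSUMPTION: pad_value never appears in `data`, so padding can never match.
    # ASSUMPTION: every data value occurs exactly once in its comparison row.
    comp = pad_sequence(compare_list, batch_first=True, padding_value=pad_value)  # (B, K)
    # Batched lookup: (B, N, 1) == (B, 1, K) broadcasts to (B, N, K).
    mask = data.unsqueeze(2) == comp.unsqueeze(1)
    # Position of the single match along the last dimension.
    ranks = mask.long().argmax(dim=2)                                             # (B, N)
    # Stable sort so that equal values keep their original relative order.
    return torch.sort(ranks, dim=1, stable=True).indices

a = torch.tensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare = [torch.tensor([2, 1]), torch.tensor([2, 3, 1])]
padded_result = padded_argsort(a, compare)
print(padded_result)
```

Whether this is actually faster than the loop depends on the batch size and row lengths, so it is worth timing on real data.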
Here is a pytorch 0.3.0 version of your magic_argsort() function
applied to your example:
import torch
torch.__version__
a = torch.LongTensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare = [torch.LongTensor([2, 1]), torch.LongTensor([2, 3, 1])]
# assumes that dataTensor and compareList are equal-length iterables of 1D tensors
# assumes that values in rows of dataTensor do occur in corresponding rows of compareList
def magic_argsort (dataTensor, compareList):
    r = []
    for d, c in zip (dataTensor, compareList):
        mask = d.unsqueeze (1) == c   # look up data values; unsqueeze so c can broadcast
        inds = mask.nonzero()[:,1]    # convert to sort-order indices
        r.append (inds.sort()[1])     # 0.3.0 doesn't have separate argsort function
    return torch.stack (r)            # return as torch tensor
magic_argsort (a, compare)
And here is the output:
>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> a = torch.LongTensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
>>> compare = [torch.LongTensor([2, 1]), torch.LongTensor([2, 3, 1])]
>>>
>>> # assumes that dataTensor and compareList are equal-length iterables of 1D tensors
... # assumes that values in rows of dataTensor do occur in corresponding rows of compareList
... def magic_argsort (dataTensor, compareList):
...     r = []
...     for d, c in zip (dataTensor, compareList):
...         mask = d.unsqueeze (1) == c   # look up data values; unsqueeze so c can broadcast
...         inds = mask.nonzero()[:,1]    # convert to sort-order indices
...         r.append (inds.sort()[1])     # 0.3.0 doesn't have separate argsort function
...     return torch.stack (r)            # return as torch tensor
...
>>> magic_argsort (a, compare)
0 3 4 5 1 2 6
3 4 5 0 1 2 6
[torch.LongTensor of size 2x7]
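For readers on a newer PyTorch, here is a hedged sketch of the same function. It assumes torch.tensor and the stable keyword of torch.sort are available, and it adds an explicit check for the assumption stated in the comments above (every data value occurs in its comparison row). The name custom_argsort and the error message are illustrative, not part of any library.

```python
import torch

def custom_argsort(data, compare_list):
    """Argsort each row of `data` by the order given in the matching entry of `compare_list`."""
    rows = []
    for d, c in zip(data, compare_list):
        mask = d.unsqueeze(1) == c  # (len(d), len(c)) lookup table
        # Guard the stated assumption: each data value matches exactly one entry of c.
        if not bool((mask.sum(dim=1) == 1).all()):
            raise ValueError("each data value must occur exactly once in its comparison row")
        ranks = mask.nonzero()[:, 1]  # position of each data value within c
        # A stable sort keeps ties in their original order, matching the output shown above.
        rows.append(torch.sort(ranks, stable=True).indices)
    return torch.stack(rows)

a = torch.tensor([[2, 1, 1, 2, 2, 2, 1], [3, 1, 1, 2, 2, 2, 1]])
compare = [torch.tensor([2, 1]), torch.tensor([2, 3, 1])]
result = custom_argsort(a, compare)
print(result)
```

Without the check, a value missing from its comparison row would silently shorten that row's result, so failing early seems the safer choice here.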
Thanks, K. Frank. I had also come up with a similar idea, but without the unsqueeze broadcast trick.
I now use this trick in my code, and it makes the corresponding code up to 4x faster.