How to speed up torch.argsort / torch.sort? Runtime depends on the sorting axis

hi,
I have a large float CUDA tensor v of shape (32, 6, 59536).
Applying torch.argsort(v, dim=1, descending=True) takes 14 ms, and it is called twice, so it weighs heavily on the overall runtime.
Is there any way to speed up this op?

The same question goes for torch.sort.

Here is a weird behavior: the runtime of torch.sort depends on the sorting axis!

import numpy as np
import torch


def main(axis=-1):
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cpu = np.random.rand(32, 6, 59536).astype(np.float32)
    gpu = torch.from_numpy(cpu).to('cuda:0')

    # Time NumPy's sort on the host copy. CUDA events still work here
    # because np.sort blocks the host between the two record() calls.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()

    np.sort(cpu, axis=axis)

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print('time cpu: {}'.format(elapsed_time_ms))

    # Time the GPU sort with the usual record -> op -> record -> sync pattern.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()

    torch.sort(gpu, dim=axis, descending=True)

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print('time gpu: {}'.format(elapsed_time_ms))


if __name__ == '__main__':

    for axis in [0, 1, 2]:
        print('*************** sorting axis={} *************'.format(axis))
        for i in range(4):
            print('run {}'.format(i))
            main(axis=axis)

output:

*************** sorting axis=0 *************
run 0
time cpu: 272.53204345703125
time gpu: 6.794528007507324
run 1
time cpu: 268.1492614746094
time gpu: 5.733503818511963
run 2
time cpu: 268.2332458496094
time gpu: 5.738399982452393
run 3
time cpu: 271.6305236816406
time gpu: 5.7454400062561035
*************** sorting axis=1 *************
run 0
time cpu: 163.71519470214844
time gpu: 15.351776123046875
run 1
time cpu: 163.94012451171875
time gpu: 15.353407859802246
run 2
time cpu: 165.7310791015625
time gpu: 15.347135543823242
run 3
time cpu: 163.39231872558594
time gpu: 15.351840019226074
*************** sorting axis=2 *************
run 0
time cpu: 817.253173828125
time gpu: 6.221983909606934
run 1
time cpu: 820.2147827148438
time gpu: 5.75600004196167
run 2
time cpu: 817.254638671875
time gpu: 5.731488227844238
run 3
time cpu: 814.9271850585938
time gpu: 5.727200031280518

Sorting along the extreme axes (first or last) seems much faster than sorting along an inner axis, at least for PyTorch. NumPy behaves differently.
Any explanation?

thanks

You could split the tensor along one of the non-sorted dimensions, sort the chunks in parallel on different GPUs, and then concatenate them back together on a single GPU.
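A minimal sketch of that idea, shown on the CPU with NumPy for simplicity (in practice each chunk would be moved to its own cuda:N device and sorted there): since the chunks are split along a non-sorted axis, sorting them independently and concatenating gives the same result as one full sort.

```python
import numpy as np

# Toy stand-in for the (32, 6, 59536) tensor from the question.
v = np.random.rand(8, 6, 100).astype(np.float32)

# Split along axis 0 (not the sort axis), sort each chunk along axis 1,
# then concatenate. The rows are independent along axis 0, so this
# matches sorting the whole array in one call.
chunks = np.split(v, 2, axis=0)
sorted_chunks = [np.sort(c, axis=1) for c in chunks]
result = np.concatenate(sorted_chunks, axis=0)

reference = np.sort(v, axis=1)
print(np.array_equal(result, reference))  # True
```

The chunk count and axes here are arbitrary; on multiple GPUs the per-chunk sorts can actually overlap, whereas this CPU version runs them sequentially.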

Sorting along the extreme axes (first or last) seems much faster than sorting along an inner axis, at least for PyTorch. NumPy behaves differently.
Any explanation?

This might have to do with the contiguous data layout for some sorting axes. For example, if you sort along the last axis, the data along that axis is contiguous in memory, so the kernel can read it with coalesced accesses.
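One way to test that hypothesis, and possibly work around it: move the sort axis to the end, make the copy contiguous, sort along dim=-1, and move the axis back. This is a sketch to benchmark, not a guaranteed win; the transpose plus copy has its own cost, and the tensor here is shrunk so it runs on CPU. The results are identical either way:

```python
import torch

v = torch.rand(32, 6, 512)  # smaller last dim than the original (32, 6, 59536)

# Direct sort along the strided middle axis.
direct = torch.sort(v, dim=1, descending=True).values

# Workaround sketch: make the sort axis the last, contiguous axis,
# sort along dim=-1, then transpose the result back.
vt = v.transpose(1, 2).contiguous()
indirect = torch.sort(vt, dim=-1, descending=True).values.transpose(1, 2)

print(torch.equal(direct, indirect))  # True
```

Whether the contiguous-sort path beats the strided one for your shape is something only a benchmark on your GPU can settle.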