Hi,
I have a large float CUDA tensor v of shape (32, 6, 59536).
Calling torch.argsort(v, dim=1, descending=True)
takes 14 ms.
It is called twice, so it weighs heavily on the total runtime.
Is there any way to speed up this op?
The same question goes for torch.sort.
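If the two calls sort the same tensor, or if the sorted values are later gathered with the argsort indices, a single torch.sort call already returns both values and indices, which halves the sorting work. Below is a minimal sketch under that assumption. The shrunken last dimension and the CPU fallback are only there so it runs quickly anywhere.

```python
import torch

# Small stand-in for the (32, 6, 59536) tensor; the last dim is shrunk and the
# device falls back to the CPU when no GPU is present (illustration only).
device = "cuda" if torch.cuda.is_available() else "cpu"
v = torch.rand(32, 6, 1024, device=device)

# One torch.sort call yields both the sorted values and the permutation
# indices, so a separate torch.argsort call (or a later gather) is not needed.
values, indices = torch.sort(v, dim=1, descending=True)

# The indices describe the same permutation argsort returns: gathering
# with them reproduces the sorted values.
assert torch.equal(torch.gather(v, 1, indices), values)
```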
Here is a weird behavior: the runtime of torch.sort
depends on the sorting axis!
```python
import numpy as np
import torch


def main(axis=-1):
    # Fix all seeds so every run sorts the same data.
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cpu = np.random.rand(32, 6, 59536).astype(np.float32)
    gpu = torch.tensor(cpu).to(device='cuda:0')

    # Time the NumPy sort (runs on the CPU) between two CUDA events.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    np.sort(cpu, axis=axis)[1]
    torch.cuda.synchronize()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print('time cpu: {}'.format(elapsed_time_ms))

    # Time the PyTorch sort on the GPU.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    torch.sort(gpu, dim=axis, descending=True).values[1]
    torch.cuda.synchronize()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print('time gpu: {}'.format(elapsed_time_ms))


if __name__ == '__main__':
    for axis in [0, 1, 2]:
        print('*************** sorting axis={} *************'.format(axis))
        for i in range(4):
            print('run {}'.format(i))
            main(axis=axis)
```
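For comparing axes, torch.utils.benchmark.Timer may be a more reliable harness than hand-placed CUDA events. It performs warm-up runs and synchronizes CUDA for you, which would likely also remove the slower run 0 seen in the output. This is a sketch only: median_sort_ms is a hypothetical helper name, and the CPU fallback exists just so it runs without a GPU.

```python
import torch
import torch.utils.benchmark as benchmark


def median_sort_ms(x, axis, min_run_time=0.5):
    # Hypothetical helper: median time in ms of a descending sort along `axis`.
    # Timer warms up and synchronizes CUDA around the measured statement.
    timer = benchmark.Timer(
        stmt="torch.sort(x, dim=axis, descending=True)",
        globals={"torch": torch, "x": x, "axis": axis},
    )
    return timer.blocked_autorange(min_run_time=min_run_time).median * 1e3


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.rand(32, 6, 59536, device=device)
    for axis in (0, 1, 2):
        print(f"axis={axis}: median {median_sort_ms(x, axis):.3f} ms")
```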
Output:
```
*************** sorting axis=0 *************
run 0
time cpu: 272.53204345703125
time gpu: 6.794528007507324
run 1
time cpu: 268.1492614746094
time gpu: 5.733503818511963
run 2
time cpu: 268.2332458496094
time gpu: 5.738399982452393
run 3
time cpu: 271.6305236816406
time gpu: 5.7454400062561035
*************** sorting axis=1 *************
run 0
time cpu: 163.71519470214844
time gpu: 15.351776123046875
run 1
time cpu: 163.94012451171875
time gpu: 15.353407859802246
run 2
time cpu: 165.7310791015625
time gpu: 15.347135543823242
run 3
time cpu: 163.39231872558594
time gpu: 15.351840019226074
*************** sorting axis=2 *************
run 0
time cpu: 817.253173828125
time gpu: 6.221983909606934
run 1
time cpu: 820.2147827148438
time gpu: 5.75600004196167
run 2
time cpu: 817.254638671875
time gpu: 5.731488227844238
run 3
time cpu: 814.9271850585938
time gpu: 5.727200031280518
```
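In PyTorch, sorting along the outer axes (dim 0 or dim 2, about 5.7 ms) is much faster than sorting along the middle axis (dim 1, about 15.3 ms), even though dim 1 has only 6 elements. NumPy behaves differently: dim 1 is its fastest axis (about 164 ms) and dim 2 its slowest (about 815 ms).
Any explanation?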
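One way to probe this, based on a guess rather than a confirmed explanation, is to check whether sorting along a strided middle axis is the cost. Move the sort axis to the end, make the tensor contiguous, sort along the last dim, and move the axis back. Then time this against torch.sort(v, dim=1) directly. sort_via_last_dim is a hypothetical helper. Whether it actually helps depends on the GPU and PyTorch version, and the extra contiguous copy costs memory bandwidth.

```python
import torch


def sort_via_last_dim(x, dim, descending=True):
    # Hypothetical helper: move the sort axis to the end and materialize a
    # contiguous copy, so each slice's elements are adjacent in memory.
    moved = x.movedim(dim, -1).contiguous()
    values, indices = torch.sort(moved, dim=-1, descending=descending)
    # Move the sorted axis back to its original position (returns views).
    return values.movedim(-1, dim), indices.movedim(-1, dim)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    v = torch.rand(32, 6, 59536, device=device)
    vals, idx = sort_via_last_dim(v, 1)
    # Same sorted values as sorting dim=1 directly.
    assert torch.equal(vals, torch.sort(v, dim=1, descending=True).values)
    assert torch.equal(torch.gather(v, 1, idx), vals)
```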
Thanks