torch.unique is very slow for a large number of unique values

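```python
import torch
import time

N = int(1e7)
# 1e7 random int64 values drawn from [0, 100000)
x = torch.LongTensor(N).random_(100000)
x = x.cuda()

torch.cuda.synchronize()  # make sure the host-to-device copy has finished
start = time.time()
unique, inv = torch.unique(x, sorted=False, return_inverse=True)
torch.cuda.synchronize()  # wait for torch.unique to finish before stopping the clock
print(time.time() - start)
```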

This takes 2 minutes on a Titan X, versus 0.89 s on the CPU.
It seems the torch.unique runtime on the GPU grows rapidly with the number of unique values?
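To check whether the GPU runtime really scales with the number of unique values, one could keep N fixed and sweep the range the values are drawn from. The following is a minimal sketch, assuming a smaller N (10**6) and the hypothetical helper names `time_unique` and `sweep_num_uniques`; it synchronizes around the timed region so that GPU timings include the kernels themselves and not only their launches.

```python
import time
import torch

def time_unique(x, repeats=3):
    """Average wall-clock seconds per torch.unique call on tensor x."""
    use_cuda = x.is_cuda
    if use_cuda:
        torch.cuda.synchronize()  # drain pending work before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        torch.unique(x, sorted=True, return_inverse=True)
    if use_cuda:
        torch.cuda.synchronize()  # wait for the kernels to actually finish
    return (time.perf_counter() - start) / repeats

def sweep_num_uniques(n, value_ranges, device):
    """Time torch.unique for values drawn from [0, k) for each k in value_ranges."""
    results = {}
    for k in value_ranges:
        x = torch.randint(0, k, (n,), device=device)
        results[k] = time_unique(x)
    return results

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for k, t in sweep_num_uniques(10**6, [10, 1000, 100000], device).items():
        print(f"values in [0, {k}): {t * 1000:.1f} ms")
```

Note that `k` is only an upper bound on the number of distinct values; the actual count can be smaller when N is not much larger than `k`.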

EDIT:
I tried rusty's "Pytorch Unique" package and it is much faster on the GPU (around 40 ms), but much slower on the CPU (7.19 s). https://github.com/rusty1s/pytorch_unique


Thanks for reporting it!
I could reproduce this issue and will have a look at the current implementation.

A small clarification: the alternative unique implementation does not return the inverse indices, so it is probably not directly comparable to torch.unique.
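For context, the inverse indices map every element of `x` to the position of its value in the unique output, so `uniq[inv]` reconstructs `x`. If an implementation returns only the sorted unique values, one way to recover the inverse is a binary search with `torch.searchsorted`. This is a swapped-in technique, not what either library does internally, and `torch.searchsorted` only exists in PyTorch 1.6 and later (not in the 1.0 release discussed in this thread). The helper name `inverse_from_sorted_unique` is hypothetical.

```python
import torch

def inverse_from_sorted_unique(x, uniq):
    """Recover inverse indices given the *sorted* unique values of x.

    For each element of x, searchsorted returns the leftmost position in
    uniq where it could be inserted; since every element of x occurs in
    uniq, that position is exactly the index of its value.
    """
    return torch.searchsorted(uniq, x)

x = torch.tensor([3, 1, 3, 7, 1, 1])
uniq = torch.unique(x, sorted=True)          # tensor([1, 3, 7])
inv = inverse_from_sorted_unique(x, uniq)    # tensor([1, 0, 1, 2, 0, 0])
assert torch.equal(uniq[inv], x)             # the inverse reconstructs x
```

Each lookup is O(log u) for u unique values, so the whole pass costs O(n log u) regardless of how many values repeat.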

That's right. For the CPU it uses the NumPy implementation, which might be a bit slower.
However, the gap between CPU and GPU tensors in the current torch implementation is quite large, so I would like to check whether something changed internally, since the code was benchmarked before it was released.
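On the CPU side, NumPy's `np.unique` can return the inverse indices directly via `return_inverse=True`. Below is a minimal sketch of such a CPU path, assuming a hypothetical helper name `unique_cpu_numpy`; it is not the pytorch_unique package's actual code.

```python
import numpy as np
import torch

def unique_cpu_numpy(x):
    """Sorted unique values and inverse indices of a CPU tensor, computed by NumPy."""
    uniq, inv = np.unique(x.numpy(), return_inverse=True)  # uniq is sorted
    # torch.from_numpy shares memory with the NumPy arrays (no copy).
    # Reshape because the inverse's shape for N-D input differs across NumPy versions.
    return torch.from_numpy(uniq), torch.from_numpy(inv).reshape(x.shape)

x = torch.randint(0, 100, (1000,))
uniq, inv = unique_cpu_numpy(x)
assert torch.equal(uniq[inv], x)
```

`x.numpy()` requires a CPU tensor that does not require grad, so a GPU tensor would first need `.cpu()`, which adds a device-to-host copy to the cost.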

Hello, did you find the reason behind the slow GPU runtime?

I'm still investigating, using the old tests I wrote for torch.unique, whether this is a performance regression introduced by a certain patch or whether it has been slow for these large 1-D tensors from the beginning.

Hello! May I ask where the CUDA kernel code for torch.unique is in the PyTorch repository?

Sure, you can find it here.

Thanks for the link. I have just noticed that if I remove torch.cuda.synchronize(), torch.unique on the GPU seems much faster now (about 10 ms vs. 2 minutes; note: I recently updated to torch 1.0). So is it possible that the problem is not so much the function itself, but a bug somewhere else?

```python
import torch
import time

start = time.time()
N = int(1e7)
Nuniq = int(1e6)
x = torch.LongTensor(N).random_(Nuniq)
x = x.cuda()
runtime = (time.time() - start)
print('runtime allocation: ', runtime)

rounds = 10  # number of timed iterations (value assumed; not given in the original post)
start = time.time()
for i in range(rounds):
    uniq, inv = torch.unique(x, return_inverse=True)
    uniq[...] = 0  # add this to be sure
    inv[...] = 0
# note: no torch.cuda.synchronize() before reading the clock here
runtime = (time.time() - start) / rounds
print('runtime: ', runtime)
m_items = N / 1000000.0
mevs = m_items / runtime  # million elements per second
print('mev/s: ', mevs)
```

Since CUDA calls are asynchronous, you should synchronize the device before starting and before stopping the timer. Otherwise you'll only be timing the kernel launches, not the kernels themselves.
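Besides `torch.cuda.synchronize()` with a host timer, CUDA events recorded on the stream give device-side timings. A minimal sketch using `torch.cuda.Event` (`elapsed_time` returns milliseconds); the helper name `timed_unique_ms` is hypothetical, and it falls back to a host timer on the CPU so it still runs without a GPU.

```python
import time
import torch

def timed_unique_ms(x):
    """Return (milliseconds, (uniq, inv)) for one torch.unique call on x."""
    if x.is_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()   # make sure earlier queued work is not counted
        start.record()
        out = torch.unique(x, sorted=True, return_inverse=True)
        end.record()
        end.synchronize()          # block until the recorded work has finished
        return start.elapsed_time(end), out
    t0 = time.perf_counter()
    out = torch.unique(x, sorted=True, return_inverse=True)
    return (time.perf_counter() - t0) * 1000.0, out

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randint(0, 100000, (10**6,), device=device)
ms, (uniq, inv) = timed_unique_ms(x)
print(f"torch.unique on {device}: {ms:.2f} ms")
```

Running the call once before timing (a warm-up) is usually advisable, since the first CUDA call in a process also pays for context and allocator setup.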

OK. So is it just due to the quadratic inverse_kernel_cuda (when the number of unique values is roughly the number of elements)?
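If the inverse computation is quadratic when the number of unique values approaches the number of elements, a sort-based formulation avoids that and runs in O(n log n). Below is a minimal sketch of this technique built only from standard tensor ops (sort, comparison, cumsum, indexing); the helper name `sort_based_unique` is hypothetical, and this is not claimed to be how torch.unique is (or was) implemented.

```python
import torch

def sort_based_unique(x):
    """Sorted unique values of a 1-D tensor x plus inverse indices, via one sort."""
    sorted_x, perm = torch.sort(x)
    # True at the first element of every run of equal values.
    is_new = torch.ones_like(sorted_x, dtype=torch.bool)
    is_new[1:] = sorted_x[1:] != sorted_x[:-1]
    uniq = sorted_x[is_new]
    # Running count of run starts gives each sorted element its unique index.
    idx_sorted = torch.cumsum(is_new.long(), dim=0) - 1
    # Scatter the indices back to the original (unsorted) positions.
    inv = torch.empty_like(idx_sorted)
    inv[perm] = idx_sorted
    return uniq, inv

x = torch.randint(0, 10**6, (10**6,))
uniq, inv = sort_based_unique(x)
assert torch.equal(uniq[inv], x)
```

The cost is dominated by the sort, so it does not depend on how many unique values there are; the remaining steps are linear, element-wise passes.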

The issue is being tracked here.
Thanks for reporting it @Etienne_Perot!