torch.bincount() ~1000x slower on CUDA

Hello, hoping for a little insight. I’m writing a loss function that calculates, among other things, the entropy of images. I’m using torch.bincount to create a grey-level co-occurrence matrix (GLCM). On CUDA, the function is nearly 1000x slower when the input tensor contains a large number of zeros than when it doesn't. Here’s some toy code that minimally reproduces the problem:

import time

import numpy as np
import torch

# a: 100,000 values drawn uniformly from [0, 6500)
a = torch.randint(0, 6500, (100000,))
# b: 100,000 zeros, with up to 2,700 entries overwritten by random values in [0, 6500)
b = torch.zeros(100000, dtype=torch.long)
for _ in range(2700):
    idx = np.random.choice(100000)
    num = np.random.choice(6500)
    b[idx] = num

print('a unique values', len(torch.unique(a)))
print('b unique values', len(torch.unique(b)))
print('a max', a.max())
print('b max', b.max())
print('a # of zeros: ', (a == 0).sum())
print('b # of zeros: ', (b == 0).sum())

print('CPU Bincount:')
for i, x in enumerate([a, b]):
    time1 = time.time()
    torch.bincount(x)
    time2 = time.time()
    print('Tensor {}, bincount time: {:4f} s'.format(i, (time2-time1)))

print('Cuda Bincount:')
for i, x in enumerate([a, b]):
    x = x.cuda()
    torch.cuda.synchronize()  # make sure the host-to-device copy has finished
    time1 = time.time()
    torch.bincount(x)
    torch.cuda.synchronize()  # wait for the kernel to finish before stopping the clock
    time2 = time.time()
    print('Tensor {}, bincount time: {:4f} s'.format(i, (time2-time1)))

Output:

a unique values 6500
b unique values 2198
a max tensor(6499)
b max tensor(6499)
a # of zeros: tensor(19)
b # of zeros: tensor(97343)
CPU Bincount:
Tensor 0, bincount time: 0.002553 s
Tensor 1, bincount time: 0.002548 s
Cuda Bincount:
Tensor 0, bincount time: 0.001399 s
Tensor 1, bincount time: 1.475042 s

Can anyone shed some light on this? Does this have anything to do with the “non-deterministic” behavior of bincount() on CUDA?
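For context on why the real input is zero-heavy, here is a minimal sketch of building a GLCM with torch.bincount and taking its entropy. It assumes a 2-D integer image with values in [0, levels) and a single horizontal offset of one pixel; `glcm_entropy` and that offset are illustrative choices, not the original poster's actual loss function.

```python
import torch


def glcm_entropy(img: torch.Tensor, levels: int) -> torch.Tensor:
    """Shannon entropy (bits) of a horizontal-offset GLCM.

    Hypothetical helper: assumes `img` is a 2-D integer tensor with values
    in [0, levels) and uses an offset of one pixel to the right.
    """
    left = img[:, :-1].reshape(-1)   # grey level of each pixel
    right = img[:, 1:].reshape(-1)   # grey level of its right-hand neighbour
    # Encode each (i, j) grey-level pair as one flat index i * levels + j.
    pair_idx = left.long() * levels + right.long()
    counts = torch.bincount(pair_idx, minlength=levels * levels)
    glcm = counts.reshape(levels, levels).float()
    p = glcm / glcm.sum()
    # Skip empty cells so that 0 * log(0) is treated as 0.
    nz = p[p > 0]
    return -(nz * torch.log2(nz)).sum()


img = torch.tensor([[0, 1], [1, 0]])
print(glcm_entropy(img, levels=2))  # two equally likely pairs
```

If an image has a large uniform background at grey level 0 (an assumption about the data), most pairs are (0, 0), i.e. flat index 0, which is exactly the zero-heavy input that tensor b imitates.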

Hi Andrew,

yes, in a way they’re related.
Bincount seems to eventually reduce to kernelHistogram1D in SummaryOps.cu. That kernel uses atomicAdds, which are what make it non-deterministic, and they perform poorly when many threads write to the same memory location. In your tensor b about 97% of the entries are 0, so nearly every thread is contending for the same bin.
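Since the slowdown comes from contention on one hot bin, one possible workaround (not suggested in the thread, and not benchmarked here) is to count the hot value with an ordinary reduction and run bincount only on the remaining entries, which are spread over many bins. `bincount_split_hot_bin` is a hypothetical helper; it assumes the hot value (0 by default) is known in advance.

```python
import torch


def bincount_split_hot_bin(x: torch.Tensor, minlength: int = 0, hot: int = 0) -> torch.Tensor:
    """torch.bincount on a 1-D non-negative int tensor, counting `hot` separately.

    Hypothetical helper: assumes most entries equal `hot`, so counting that
    value with a sum avoids many threads hitting the same bin at once.
    """
    mask = x == hot
    n_hot = mask.sum()                      # count of the hot value, via a reduction
    # Boolean indexing drops the hot entries (and forces a device sync on CUDA).
    counts = torch.bincount(x[~mask], minlength=minlength)
    # Ensure the hot bin exists even if every remaining value is smaller.
    if counts.numel() <= hot:
        counts = torch.cat([counts, counts.new_zeros(hot + 1 - counts.numel())])
    counts[hot] += n_hot
    return counts


x = torch.tensor([0, 0, 0, 3, 1, 0, 3])
print(bincount_split_hot_bin(x), torch.bincount(x))
```

Whether this is actually faster depends on the GPU and PyTorch version, so it would need re-timing with the benchmark from the question.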

I think one (hacky) way to avoid this for now is to use Tensor.put_ with accumulate=True, but that isn’t batched, etc.
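The put_ workaround might look roughly like the sketch below. It assumes a 1-D non-negative int64 input, as torch.bincount expects, and mirrors bincount without weights; `bincount_via_put` is a hypothetical name, and whether the accumulate path avoids the contention on a particular PyTorch version is not verified here.

```python
import torch


def bincount_via_put(x: torch.Tensor, minlength: int = 0) -> torch.Tensor:
    """Histogram of a 1-D non-negative int64 tensor using Tensor.put_.

    Hypothetical sketch of the put_(accumulate=True) workaround.
    Not batched: handles a single 1-D tensor per call.
    """
    if x.numel() == 0:
        return torch.zeros(minlength, dtype=torch.long, device=x.device)
    size = max(int(x.max().item()) + 1, minlength)
    counts = torch.zeros(size, dtype=torch.long, device=x.device)
    # Add 1 at each index in x; with accumulate=True, repeated indices
    # add up instead of overwriting each other.
    counts.put_(x, torch.ones_like(x), accumulate=True)
    return counts


x = torch.tensor([0, 0, 0, 3, 1, 0, 3])
print(bincount_via_put(x), torch.bincount(x))
```

On the GPU, the call is the same once x has been moved with x.cuda().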

Best regards

Thomas

Thanks for the reply!