Suppose I have the following two matrices:
x = torch.randint(0, 256, (100000, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (1000, 48), dtype=torch.uint8)
I wish to do a matrix-multiplication-like operation where, for each pair of rows, I compare the 48 dimensions and count the elements that are equal. The following operation takes 7.81 seconds, and batching does not seem to help:
matrix = (x_new.unsqueeze(1) == x).sum(dim=-1)
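For clarity, the broadcasted comparison above counts, for every (query row, database row) pair, how many of the 48 byte positions agree. A small sketch with reduced shapes (sizes here are illustrative, not the original 100000×48):

```python
import torch

# Smaller stand-ins for x (database) and x_new (queries)
x = torch.randint(0, 256, (50, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (5, 48), dtype=torch.uint8)

# Broadcasted equality: (5, 1, 48) == (50, 48) -> (5, 50, 48), then count matches
matrix = (x_new.unsqueeze(1) == x).sum(dim=-1)

# Equivalent explicit double loop, for reference
manual = torch.empty(5, 50, dtype=torch.int64)
for i in range(5):
    for j in range(50):
        manual[i, j] = (x_new[i] == x[j]).sum()

assert torch.equal(matrix, manual)
```

The broadcast materializes the full (queries × database × 48) boolean tensor before reducing, which is where the memory traffic comes from at the original sizes.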
However, doing a simple matrix multiplication
matrix = x_new @ x.T
takes 3.54 seconds. I understand this most likely calls into a lower-level library that isn't slowed down by Python. The question, however, is: is there a way to speed up the multiplication-like operation, via scripting or any other means?
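The timings above can be reproduced with a small harness along these lines (shapes reduced here, so absolute numbers will differ; the uint8 matmul is guarded because integer matmul support varies by backend and version):

```python
import time
import torch

x = torch.randint(0, 256, (2000, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (200, 48), dtype=torch.uint8)

def bench(fn, repeats=5):
    fn()  # warm-up run
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - t0) / repeats

# The broadcast-and-count operation from the question
t_eq = bench(lambda: (x_new.unsqueeze(1) == x).sum(dim=-1))

# Plain uint8 matmul; not implemented on every backend, so guard it
try:
    t_mm = bench(lambda: x_new @ x.T)
except RuntimeError:
    t_mm = None

print(f"equality-sum: {t_eq * 1e3:.2f} ms, uint8 matmul: {t_mm}")
```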
What is even stranger is that
matrix = x_new.float() @ x.float().T
takes 214 ms, more than 10x faster than the uint8 multiplication. To add to the strangeness, uint8 is a single byte whereas float32 is four bytes, so I would have expected the uint8 version to be the faster one.
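One thing worth checking is that the float path is not just faster but exact for this data: every dot product is at most 48 * 255 * 255 = 3,121,200, well inside float32's 2**24 exactly-representable integer range. A sketch comparing it against a wide-integer reference (shapes reduced; uint8 matmul itself is avoided, since it would overflow modulo 256 even where supported):

```python
import torch

x = torch.randint(0, 256, (100, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (10, 48), dtype=torch.uint8)

# Wide-integer reference: no overflow possible in int64
ref = x_new.to(torch.int64) @ x.to(torch.int64).T

# float32 path: exact here, because all results fit in 24-bit integers
fast = x_new.float() @ x.float().T

assert torch.equal(fast.to(torch.int64), ref)
```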
For context, I am trying to quantize vectors so that I can find the closest vector by comparing integers rather than computing dot products directly.
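A minimal sketch of that use case, assuming "closest" means the database row sharing the most equal bytes with each query (the names, shapes, and planted index are illustrative, not from the original setup):

```python
import torch

torch.manual_seed(0)
database = torch.randint(0, 256, (1000, 48), dtype=torch.uint8)
queries = torch.randint(0, 256, (8, 48), dtype=torch.uint8)

# Plant an exact duplicate so query 0 has a known nearest neighbour
queries[0] = database[123]

# Match counts per (query, database) pair, then pick the best match per query
matches = (queries.unsqueeze(1) == database).sum(dim=-1)  # shape (8, 1000)
nearest = matches.argmax(dim=1)                           # shape (8,)

assert nearest[0] == 123
```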