Suppose I have the following two matrices:

```
x = torch.randint(0, 256, (100000, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (1000, 48), dtype=torch.uint8)
```

I wish to do a matrix-multiplication-like operation where, for every pair of rows, I compare the 48 dimensions and count the elements that are equal. The following operation takes 7.81 seconds, and batching does not seem to help:

```
matrix = (x_new.unsqueeze(1) == x).sum(dim=-1)
```
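For reference, the chunked variant I mean by "batching" looks roughly like this (the chunk size of 50 is an arbitrary choice, and the shapes are reduced here only to keep the sketch quick to run):

```python
import torch

# Reduced shapes for the sketch; the real sizes are (100000, 48) and (1000, 48).
x = torch.randint(0, 256, (10000, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (200, 48), dtype=torch.uint8)

# Batched variant: process x_new in chunks so the broadcasted boolean
# intermediate has shape (chunk, 10000, 48) instead of (200, 10000, 48).
parts = [(xb.unsqueeze(1) == x).sum(dim=-1) for xb in x_new.split(50, dim=0)]
matrix = torch.cat(parts, dim=0)

# Same result as the unbatched version.
reference = (x_new.unsqueeze(1) == x).sum(dim=-1)
assert torch.equal(matrix, reference)
```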

However, doing a simple matrix multiplication (`matrix = x_new @ x.T`) takes 3.54 seconds. I understand this is most likely calling a deeper library that isn't slowed down by Python. The question is: **is there a way to speed up the comparison operation**, by using scripting or any other way at all?
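By "scripting" I mean something like wrapping the comparison in `torch.jit.script`; a minimal sketch of that (with small shapes so it runs quickly):

```python
import torch

@torch.jit.script
def match_count(x_new: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # For each row of x_new, count positions equal to each row of x.
    # Result shape: (x_new.shape[0], x.shape[0]).
    return (x_new.unsqueeze(1) == x).sum(dim=-1)

x = torch.randint(0, 256, (1000, 48), dtype=torch.uint8)
x_new = torch.randint(0, 256, (100, 48), dtype=torch.uint8)
matrix = match_count(x_new, x)
```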

What is even stranger is that `matrix = x_new.float() @ x.float().T` takes 214 ms, which is more than 10x faster than the `uint8` multiplication. To add to the strangeness, `uint8` is a single byte whereas `float` is 4 bytes, so I would have expected the `uint8` version to be faster.

For context, I am trying to quantize vectors so that I can find the closest vector by comparing integers rather than by directly computing dot products.
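Concretely, the lookup I am after is just an argmax over the match counts (a sketch with small shapes; the variable names are illustrative):

```python
import torch

x = torch.randint(0, 256, (1000, 48), dtype=torch.uint8)    # quantized database
x_new = torch.randint(0, 256, (10, 48), dtype=torch.uint8)  # quantized queries

# Match counts, shape (10, 1000): entry [i, j] is the number of
# dimensions in which query i and database row j are equal.
matrix = (x_new.unsqueeze(1) == x).sum(dim=-1)

# Closest database vector per query = row with the most matching dimensions.
nearest = matrix.argmax(dim=1)  # shape (10,)
```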