Efficient (restricted) matrix multiplication implementation using indicator matrix

Hi all,

I have a dot product that I aim to optimise by using some a priori information.
I have two matrices, namely Q (n x d) and K (d x n), and since n can be very large, computing the full product Q K has quadratic complexity in n. But I also have a symmetric indicator matrix A (n x n) in which each row has knn ones and n-knn zeros (always knn < n). I use A to select the relevant columns of K for each row of Q. So my question is: how can I efficiently implement the code part marked below? And how can I store A as efficiently as possible in terms of memory?
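In symbols, using notation introduced here for illustration: let $j_{i,1} < j_{i,2} < \dots < j_{i,\mathrm{knn}}$ be the column indices of the ones in row $i$ of A. The marked code then produces an n x knn result, not an n x n one:

$$
\mathrm{outs}_{i,m} = \sum_{l=1}^{d} Q_{i,l}\, K_{l,\, j_{i,m}}, \qquad 1 \le i \le n,\quad 1 \le m \le \mathrm{knn},
$$

which needs $n \cdot \mathrm{knn} \cdot d$ multiply-adds instead of the $n^2 d$ of the full product $QK$.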

import torch
import time
import random
from sklearn.neighbors import kneighbors_graph

n = 1024
d = 512
knn = 64
q = torch.randn(n, d)
k = torch.randn(d, n)
coords = torch.tensor([(random.random(), random.random()) for _ in range(n)])
A = torch.tensor(kneighbors_graph(coords, knn, mode='connectivity', include_self=False).todense()).bool()

t1_start = time.time()

#start optimise; goal: this should be more efficient than full dot product

outs = [q[i].matmul(k[:,A[i]]) for i in range(n)]
outs = torch.stack(outs)

#end optimise

t1_stop = time.time()
time_need = t1_stop - t1_start
print(time_need) # 0.1741

# this is the full dot product
t1_start = time.time()
outs=q.matmul(k)
t1_stop = time.time()
time_need = t1_stop - t1_start
print(time_need) #0.00365
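For reference, the marked loop can be cross-checked against a one-line version that computes the full product and then keeps only the entries selected by A. This is a sketch, assuming a random indicator with exactly knn True entries per row as a stand-in for the kneighbors_graph one. It still does the full n**2 * d work, so it is a correctness baseline rather than the requested optimisation, although it does avoid the Python loop.

```python
import torch

_ = torch.manual_seed(0)
n, d, knn = 1024, 512, 64

q = torch.randn(n, d)
k = torch.randn(d, n)
# Random stand-in for the kneighbors_graph indicator: knn True entries per row
A = torch.randn(n, n).argsort(dim=1) < knn

# Loop from the question: one (d,) x (d, knn) product per row
outs = torch.stack([q[i].matmul(k[:, A[i]]) for i in range(n)])

# Full product, then keep the selected entries; boolean-mask indexing
# returns them in row-major order, so they reshape into (n, knn)
outs_masked = q.matmul(k)[A].view(n, knn)

# Loose tolerance: the two paths accumulate float32 sums in different orders
print(torch.allclose(outs, outs_masked, rtol=1e-4, atol=1e-4))
```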

Hi Agent!

You can package the set of dot products into a single einsum()
operation (“Einstein summation,” a generalized matrix multiplication),
thereby removing the loop (expressed as a python list comprehension),
by “expanding” the matrices, at the cost of materializing a d x n**2
expanded version of k:

>>> import torch
>>> torch.__version__
'1.10.2'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> n = 7
>>> d = 5
>>> r = 3   # knn
>>>
>>> q = torch.randn (n, d)
>>> k = torch.randn (d, n)
>>>
>>> # build (non-symmetric) boolean indicator matrix
...
>>> A = torch.randn (n, n).argsort (dim = 1) < r
>>>
>>> torch.all (A.sum (dim = 1) == r)   # verify correct number of "True"s in rows of A
tensor(True)
>>>
>>> # version with loop
...
>>> outs = [q[i].matmul(k[:,A[i]]) for i in range(n)]
>>> outs = torch.stack(outs)
>>>
>>> # loop-free version, but with d x n**2 "expanded" matrix
...
>>> qex = q.unsqueeze (1).expand (n, r, d).reshape (r * n, d)
>>> kex = k.unsqueeze (1).expand (d, n, n).reshape (d, n * n)[:, A.view (n * n)]
>>> outsB = torch.einsum ('ij, ji -> i', qex, kex).view (n, r)
>>>
>>> torch.allclose (outs, outsB)   # cross-check results
True
>>>
>>> # shapes of "expanded" matrices and result
...
>>> qex.shape
torch.Size([21, 5])
>>> kex.shape
torch.Size([5, 21])
>>>
>>> outsB
tensor([[-0.5084,  3.1422,  0.9380],
        [ 0.0939,  2.3767,  0.5234],
        [-3.2668, -0.2849, -1.3499],
        [-2.0864,  1.1269,  0.9776],
        [ 0.1827,  0.2465, -1.3500],
        [ 3.4298,  2.5439,  0.3097],
        [ 0.6107, -1.6878, -2.8169]])
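A related variant, sketched here as an alternative rather than the approach above, avoids the d x n**2 intermediate. Instead of expanding k over all n**2 (row, column) pairs, it gathers, for each row of q, only its r selected columns of k into an (n, r, d) tensor and reduces with einsum(). It assumes, as in the example, that every row of A has exactly r True entries; the names cols and kg are illustrative.

```python
import torch

_ = torch.manual_seed(2022)

n, d, r = 7, 5, 3   # r plays the role of knn

q = torch.randn(n, d)
k = torch.randn(d, n)
A = torch.randn(n, n).argsort(dim=1) < r   # exactly r True entries per row

# Column indices of the True entries, one row of r indices per row of A.
# nonzero() returns coordinates in row-major order, so the columns for
# row i come out in increasing order, matching k[:, A[i]].
cols = A.nonzero()[:, 1].reshape(n, r)

# Gather the selected columns of k as rows: shape (n, r, d)
kg = k.t()[cols]

# Batched dot products: out[i, m] = q[i] . kg[i, m]
out = torch.einsum('id,imd->im', q, kg)

# Cross-check against the loop version
ref = torch.stack([q[i].matmul(k[:, A[i]]) for i in range(n)])
print(torch.allclose(out, ref))   # True
```

The intermediate kg has n * r * d elements, so at the question's sizes (n = 1024, d = 512, knn = 64) it holds 33,554,432 floats, versus 536,870,912 for a d x n**2 expansion. The arithmetic is still n * knn * d multiply-adds; whether this beats the dense matmul in wall-clock time at those sizes would need to be measured.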

Assuming that knn is significantly less than n, the information in A is stored
most efficiently as a list (tensor) of the indices of its True elements. If you
first flatten A into a vector (of length n**2), you save space by using only a
single index per element rather than a (row, column) pair, and this structure
is also convenient for the above “expanded” matrix approach:

>>> # express indicator matrix more compactly as list of (flattened) indices
...
>>> Aind = torch.where (A.view (n * n))[0]
>>> Aind
tensor([ 2,  3,  5,  9, 12, 13, 16, 19, 20, 21, 26, 27, 29, 31, 33, 35, 37, 41,
        43, 47, 48])
>>> kexB = k.unsqueeze (1).expand (d, n, n).reshape (d, n * n)[:, Aind]   # expanded k
>>>
>>> torch.equal (kex, kexB)
True
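To put numbers on the storage trade-off, here is a sketch at the question's sizes (n = 1024, knn = 64), assuming a random indicator with knn True entries per row in place of the kneighbors_graph one. A dense bool tensor costs one byte per entry (n**2 bytes), while int64 flat indices cost 8 bytes per True entry (8 * n * knn bytes), so the index list only wins when knn < n / 8. Downcasting to int32 is an extra assumption of this sketch (valid while n**2 - 1 fits in int32) and halves the index storage again. The sketch also rebuilds A from the indices and splits them into per-row column indices.

```python
import torch

_ = torch.manual_seed(2022)
n, r = 1024, 64   # sizes from the question (knn = 64)

# Indicator with exactly r True entries per row (random, not symmetric)
A = torch.randn(n, n).argsort(dim=1) < r

# Compact form: flattened indices of the True elements (int64)
Aind = torch.where(A.view(n * n))[0]

# Storage: 1 byte per entry for the dense bool matrix, 8 bytes per True
# entry for int64 indices, 4 bytes per True entry after downcasting
Aind32 = Aind.to(torch.int32)
dense_bytes = A.numel() * A.element_size()
idx64_bytes = Aind.numel() * Aind.element_size()
idx32_bytes = Aind32.numel() * Aind32.element_size()
print(dense_bytes, idx64_bytes, idx32_bytes)   # 1048576 524288 262144

# Recover the dense matrix (convert to int64 before using as an index)
A_back = torch.zeros(n * n, dtype=torch.bool)
A_back[Aind32.long()] = True
A_back = A_back.view(n, n)
print(torch.equal(A, A_back))   # True

# Per-row column indices, shape (n, r): the flat index is row * n + column,
# and the indices are sorted by row, so each block of r belongs to one row
cols = torch.remainder(Aind, n).view(n, r)
```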

Best.

K. Frank