How to multiply two matrices but sum only over the top-K largest values

I have two real-valued matrices A and B with shapes (m,n) and (n,p), respectively, and a predefined constant K. The goal is to compute a new “topk multiplication” function, where entry (i,j) of the output is calculated as follows:

  1. Perform element-wise multiplication between the i-th row of A and the j-th column of B
  2. Sum only the K largest values of the result from step 1.

If K = n, this operation reduces to standard matrix multiplication. I am interested in computing this topk multiplication efficiently, especially when the inputs are batched. A naive implementation is shown below. Thank you!

import torch

a = torch.rand(2,3,5)
b = torch.rand(2,5,2)
K = 2

out = torch.zeros(2,3,2)

for i in range(2):
    for j in range(3):
        for k in range(2):
            tmp = torch.mul(a[i,j,:], b[i,:,k])
            v, _ = torch.topk(tmp, K)
            out[i,j,k] = v.sum()

Hi Quang!

If you align the indices correctly and use broadcasting, you can replace
the loops with a single element-wise tensor multiplication, followed by
topk() and sum() along the shared dimension:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> a = torch.rand(2,3,5)
>>> b = torch.rand(2,5,2)
>>> K = 2
>>>
>>> out = torch.zeros(2,3,2)
>>>
>>> for i in range(2):
...     for j in range(3):
...         for k in range(2):
...             tmp = torch.mul(a[i,j,:], b[i,:,k])
...             v, _ = torch.topk(tmp, K)
...             out[i,j,k] = v.sum()
...
>>> outB = (a.unsqueeze (3) * b.unsqueeze (1)).topk (K, dim = 2)[0].sum (dim = 2)
>>>
>>> torch.equal (out, outB)
True
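
To see why the broadcast lines up, it can help to write out the intermediate shapes. The following is a minimal sketch of the same one-liner wrapped in a function. The name topk_matmul is made up here for illustration and is not part of PyTorch. It also checks the K = n case from the question against torch.bmm().

```python
import torch


def topk_matmul(a, b, K):
    """Hypothetical helper: batched 'topk multiplication' via broadcasting.

    a has shape (batch, m, n) and b has shape (batch, n, p).
    """
    # a.unsqueeze(3) -> (batch, m, n, 1)
    # b.unsqueeze(1) -> (batch, 1, n, p)
    # Broadcasting the product gives every pairwise term a[i, j, l] * b[i, l, k].
    prod = a.unsqueeze(3) * b.unsqueeze(1)  # (batch, m, n, p)
    # Keep the K largest terms along the shared dimension n (dim 2) and sum them.
    return prod.topk(K, dim=2)[0].sum(dim=2)  # (batch, m, p)


torch.manual_seed(2023)
a = torch.rand(2, 3, 5)
b = torch.rand(2, 5, 2)

out = topk_matmul(a, b, K=2)
print(out.shape)

# With K = n every term is kept, so the result should agree with the
# ordinary batched matrix product up to floating-point rounding.
full = topk_matmul(a, b, K=a.shape[2])
print(torch.allclose(full, torch.bmm(a, b), atol=1e-6))
```

Note that the broadcast product holds batch * m * n * p elements at once, so for large inputs memory use grows much faster than with an ordinary matrix product.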

(You could instead use einsum() to “align the indices,” but using
.unsqueeze() seems a little better to me stylistically.)
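
For comparison, here is a rough sketch of what the einsum() variant could look like. The subscript string below is one possible choice, not taken from the post above: it keeps the shared index j in the output instead of summing over it, which yields the same (batch, m, n, p) product as the .unsqueeze() version.

```python
import torch

torch.manual_seed(2023)
a = torch.rand(2, 3, 5)
b = torch.rand(2, 5, 2)
K = 2

# "bij,bjk->bijk": j appears in the output, so einsum does not sum over it
# and only forms the pairwise products a[b, i, j] * b[b, j, k].
prod = torch.einsum("bij,bjk->bijk", a, b)  # shape (2, 3, 5, 2)
out_einsum = prod.topk(K, dim=2)[0].sum(dim=2)  # shape (2, 3, 2)

# Same computation using unsqueeze() and broadcasting, for comparison.
out_unsqueeze = (a.unsqueeze(3) * b.unsqueeze(1)).topk(K, dim=2)[0].sum(dim=2)

print(torch.allclose(out_einsum, out_unsqueeze))
```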

Best.

K. Frank

Hi Frank

Your solution worked well, thanks so much!

Quang