I have two real-valued matrices A and B with shapes (m, n) and (n, p) respectively, and a predefined constant K. The goal is to compute a new “topk multiplication” function, where entry (i, j) of the output is calculated as follows:

1) Perform element-wise multiplication between the i-th row of A and the j-th column of B.

2) Sum only the K largest values of the result from step 1).

Here, if K = n, this operation reduces to standard matrix multiplication. I am interested in computing this topk multiplication efficiently, especially when the inputs are batched. A naive implementation is shown below. Thank you!

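import torch

a = torch.rand(2, 3, 5)   # batch of 2 matrices, each of shape (m=3, n=5)
b = torch.rand(2, 5, 2)   # batch of 2 matrices, each of shape (n=5, p=2)
K = 2
out = torch.zeros(2, 3, 2)
for i in range(2):          # batch index
    for j in range(3):      # row of a[i]
        for k in range(2):  # column of b[i]
            # element-wise product of row j of a[i] and column k of b[i]
            tmp = torch.mul(a[i, j, :], b[i, :, k])
            # keep only the K largest products and sum them
            v, _ = torch.topk(tmp, K)
            out[i, j, k] = v.sum()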
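One possible way to remove the Python loops is to build all the element-wise products at once with broadcasting and then apply torch.topk along the last dimension. This is only a sketch, not a tuned solution: the function name topk_matmul is made up for illustration, and it assumes the intermediate tensor of shape (batch, m, p, n) fits in memory, which may not hold for large inputs.

```python
import torch

def topk_matmul(a, b, K):
    """Hypothetical helper: batched "topk multiplication".

    a: (batch, m, n), b: (batch, n, p) -> result: (batch, m, p)
    """
    # (batch, m, 1, n) * (batch, 1, p, n) broadcasts to (batch, m, p, n);
    # prod[t, i, j, :] holds row i of a[t] times column j of b[t], element-wise.
    prod = a.unsqueeze(2) * b.transpose(1, 2).unsqueeze(1)
    # Keep the K largest products along the last dimension and sum them.
    return prod.topk(K, dim=-1).values.sum(dim=-1)

a = torch.rand(2, 3, 5)
b = torch.rand(2, 5, 2)
out_vec = topk_matmul(a, b, K=2)  # shape (2, 3, 2)
```

The trade-off is memory: the loop version only ever holds one length-n vector of products, while this version materializes all m * p of them per batch element. If that is too large, the same function could be applied to chunks of rows of a.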