I have two real-valued matrices A and B with shapes (m, n) and (n, p), respectively.
I know PyTorch has fast APIs for matrix multiplication (e.g. mm, matmul, einsum), but is there an efficient way to replace the reduction they perform (the inner product) with an operation of my choice, such as the L2 norm of the residual?
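For the L2 case specifically, the residual norm can be rewritten so that the expensive part is still an ordinary matrix product, using the identity ||a - b||^2 = ||a||^2 - 2 a·b + ||b||^2. Below is a minimal sketch of that idea, assuming float inputs; `pairwise_l2_via_matmul` is a hypothetical helper name for illustration, not a PyTorch API. It only works for reductions that expand like this, not for an arbitrary g.

```python
import torch

def pairwise_l2_via_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: C[i, j] = ||A[i] - B[:, j]||_2 using one matmul.

    Relies on ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.
    """
    a_sq = (A * A).sum(dim=1, keepdim=True)  # (m, 1): squared norms of the rows of A
    b_sq = (B * B).sum(dim=0, keepdim=True)  # (1, p): squared norms of the columns of B
    sq = a_sq - 2.0 * (A @ B) + b_sq         # (m, p) via broadcasting
    # Rounding can push tiny values slightly below zero; clamp before sqrt.
    return sq.clamp_min(0.0).sqrt()

torch.manual_seed(0)
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)
C = pairwise_l2_via_matmul(A, B)  # shape (m, p)
```

This trades a little numerical precision (cancellation in the subtraction) for speed, since the work is dominated by `A @ B`.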
Specifically, I want to implement the following function f efficiently:
import torch
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)
def g(a, b):
    return (a - b).norm()
# The mysterious efficient implementation I want
C = f(A, B, g)
where C has shape (m, p) and each element C[i, j] equals g(A[i], B[:, j]). In general g could be some other function as well (e.g. a trivial summation, or the L1 norm of the concatenated vector).
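One generic option that keeps g as an ordinary Python function is nesting `vmap` from `torch.func` (available in PyTorch 2.0 and later): the inner map runs g over the columns of B for one fixed row of A, and the outer map runs that over the rows of A. A minimal sketch, assuming g takes two 1-D tensors and returns a scalar; `g_l1_concat` is a hypothetical example of "another g". If g uses an operation without a vmap batching rule, vmap falls back to a slower per-element loop (with a performance warning), and it still materializes (m, p, n)-sized intermediates for something like a - b.

```python
import torch
from torch.func import vmap  # PyTorch >= 2.0

def f(A: torch.Tensor, B: torch.Tensor, g):
    """C[i, j] = g(A[i], B[:, j]) for any g mapping two 1-D tensors to a scalar."""
    # Inner map: keep one row `a` of A fixed, map g over dim 1 (columns) of B.
    over_cols = vmap(g, in_dims=(None, 1))
    # Outer map: map over dim 0 (rows) of A, passing all of B each time.
    return vmap(over_cols, in_dims=(0, None))(A, B)

def g(a, b):
    # L2 norm of the residual, as in the question
    return (a - b).norm()

def g_l1_concat(a, b):
    # Hypothetical alternative g: L1 norm of the concatenated vector [a, b]
    return torch.cat([a, b]).abs().sum()

torch.manual_seed(0)
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)
C = f(A, B, g)              # shape (m, p)
C_cat = f(A, B, g_l1_concat)  # shape (m, p)
```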
I have tried a few approaches, but all of my solutions rely on inefficient nested for-loops:
# My inefficient implementation :(
C = torch.empty(m, p)
for i in range(m):
    for j in range(p):
        C[i, j] = g(A[i], B[:, j])
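The double loop above can usually be replaced by broadcasting: build all residuals at once as an (m, n, p) tensor, then reduce over the shared dimension n instead of taking an inner product. A minimal sketch, assuming the (m, n, p) intermediate fits in memory; the variable names are illustrative. Some g split into a per-row term plus a per-column term (like the L1 norm of the concatenation), and those need no 3-D tensor at all.

```python
import torch

torch.manual_seed(0)
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

# Residuals for every (i, j) pair at once: R[i, k, j] = A[i, k] - B[k, j].
R = A[:, :, None] - B[None, :, :]  # shape (m, n, p)

# Reduce over the shared dimension k (dim=1) instead of summing products.
C_l2 = torch.linalg.vector_norm(R, ord=2, dim=1)  # L2 norm of residual, (m, p)
C_l1 = R.abs().sum(dim=1)                         # L1 norm of residual, (m, p)

# Separable case: L1 norm of the concatenation [A[i], B[:, j]]
# equals sum|A[i]| + sum|B[:, j]|, so no (m, n, p) tensor is needed.
C_cat_l1 = A.abs().sum(dim=1)[:, None] + B.abs().sum(dim=0)[None, :]  # (m, p)
```

Memory grows as m*n*p, so for large inputs the residual tensor may need to be built in chunks of rows of A.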
I believe writing a custom CUDA operator could achieve this, but I know nothing about CUDA programming and would prefer an elegant implementation.
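For the specific case of a p-norm of the residual, a built-in already exists, so no custom CUDA operator should be needed: `torch.cdist` computes pairwise distances between the rows of its two inputs, so passing `B.T` treats the columns of B as points. A minimal sketch under that assumption; it does not cover arbitrary g.

```python
import torch

torch.manual_seed(0)
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

# cdist compares rows with rows, so transpose B to shape (p, n).
C = torch.cdist(A, B.T, p=2.0)     # C[i, j] = ||A[i] - B[:, j]||_2, shape (m, p)

# Other p-norms of the residual use the same call, e.g. the L1 norm:
C_l1 = torch.cdist(A, B.T, p=1.0)  # shape (m, p)
```

For p=2 and larger inputs, `cdist`'s default `compute_mode` may switch to the matmul-based expansion, which is faster but slightly less precise.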
Thanks.