I have two real-valued matrices A and B with shapes (m, n) and (n, p), respectively.
I know PyTorch has fast APIs (e.g. mm, matmul, einsum) for matrix multiplication, but I want to know whether there is an efficient way to swap the dimension-reduction function from the inner product to another operation (e.g. the L2 norm of the residual).
Specifically, I want to achieve the following function f efficiently:
```python
import torch

m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

def g(a, b):
    return (a - b).norm()

# The mysterious efficient implementation I want
C = f(A, B, g)
```
where C has shape (m, p), the element C[i, j] equals g(A[i], B[:, j]), and g could be replaced by other functions (e.g. a trivial summation, or the L1 norm of the concatenated vector).
I have made some attempts, but my solutions all rely on inefficient nested for-loops:
```python
# My inefficient implementation :(
C = torch.empty(m, p)
for i in range(m):
    for j in range(p):
        C[i, j] = g(A[i], B[:, j])
```
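For the L2-norm case specifically, I can avoid the loops by broadcasting an explicit difference tensor, but this does not generalize to arbitrary g, and it materializes an intermediate (m, n, p) tensor that may be memory-heavy for large sizes:

```python
import torch

torch.manual_seed(0)
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

# Broadcast: (m, n, 1) - (1, n, p) -> (m, n, p), then reduce over dim 1
C_broadcast = (A.unsqueeze(2) - B.unsqueeze(0)).norm(dim=1)

# For L2 distance only, torch.cdist computes pairwise distances between
# rows, so transposing B gives the same (m, p) result without the big
# intermediate tensor
C_cdist = torch.cdist(A, B.T)
```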
I believe writing a custom CUDA operator could achieve this, but I am completely unfamiliar with CUDA programming and would prefer an elegant pure-PyTorch implementation.
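For arbitrary g, I also wondered whether nesting torch.vmap could express the double loop without writing it in Python (a sketch assuming a PyTorch version where torch.func.vmap is available, i.e. 2.0+; I am not sure whether this is genuinely vectorized for every g or merely hides the loops):

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

def g(a, b):
    return (a - b).norm()

# The inner vmap sweeps b over the columns (dim 1) of B for a fixed a;
# the outer vmap sweeps a over the rows (dim 0) of A.
f = vmap(vmap(g, in_dims=(None, 1)), in_dims=(0, None))
C = f(A, B)  # shape (m, p)
```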