How to achieve efficient Generalized Matrix Product in PyTorch?

I have two real-valued matrices A and B with the shapes of (m, n) and (n, p), respectively.
I know there are fast APIs in PyTorch for matrix multiplication (e.g. mm, matmul, einsum), but is there an efficient way to replace the reduction, changing it from the inner product to an operation of my choosing (e.g. the L2 norm of the residual)?
Specifically, I want to achieve the following function f efficiently:

import torch

m, n, p = 3, 4, 5

A = torch.rand(m, n)
B = torch.rand(n, p)

def g(a, b):
    return (a - b).norm()

# The mysterious efficient implementation I want
C = f(A, B, g)

where C has shape (m, p) and C[i, j] equals g(A[i], B[:, j]). In fact, g could be other functions as well (e.g. a trivial summation, or the L1 norm of the concatenated vector).
I have made some attempts, but all of my solutions rely on inefficient nested for-loops.

# My inefficient implementation :(
C = torch.empty(m, p)
for i in range(m):
    for j in range(p):
        C[i, j] = g(A[i], B[:, j])

I believe writing custom CUDA operators could achieve this, but I know nothing about CUDA programming and would prefer an elegant implementation.
Thanks.

Hi Bin!

I think, in general, if your g() is built from a broadcastable operation
and a reduction operation, you can accomplish your goal with pytorch
broadcasting.
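To make the "broadcastable operation plus reduction" idea concrete, here is a minimal sketch of a generic helper. The names `generalized_matmul`, `combine` and `reduce` are hypothetical, not a PyTorch API. It assumes `combine` is an elementwise, broadcastable function of two tensors and `reduce` accepts a `dim` keyword. With `torch.mul` and `torch.sum` it reproduces the ordinary matrix product; note that the combined tensor has shape (m, n, p).

```python
import torch

def generalized_matmul(A, B, combine, reduce):
    # A: (m, n), B: (n, p).
    # A.unsqueeze(-1) is (m, n, 1) and B.unsqueeze(0) is (1, n, p), so an
    # elementwise, broadcastable `combine` yields an (m, n, p) tensor.
    pairwise = combine(A.unsqueeze(-1), B.unsqueeze(0))
    # `reduce` collapses the shared n axis (dim=1), leaving (m, p).
    return reduce(pairwise, dim=1)

m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

# Elementwise product + sum: the ordinary matrix product A @ B.
C_mm = generalized_matmul(A, B, torch.mul, torch.sum)

# Residual + L2 norm: the question's g(a, b) = (a - b).norm().
C_l2 = generalized_matmul(A, B, torch.sub, lambda t, dim: t.norm(dim=dim))

# L1 norm of the concatenated vector [a; b] equals sum_k (|a_k| + |b_k|).
C_cat_l1 = generalized_matmul(A, B, lambda a, b: a.abs() + b.abs(), torch.sum)

print(torch.allclose(C_mm, A @ B))  # True
```

The L1-of-concatenation case shows that a g which does not look like "combine then reduce" can sometimes be rewritten into that form.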

In your specific case, you can use unsqueeze() in the appropriate
locations to replace the nested loops with broadcasting, and then
apply norm():

>>> torch.__version__
'1.9.0'
>>>
>>> _ = torch.manual_seed (2021)
>>>
>>> m, n, p = 3, 4, 5
>>>
>>> A = torch.rand(m, n)
>>> B = torch.rand(n, p)
>>>
>>> def g(a, b):
...     return (a - b).norm()
...
>>> C = torch.empty(m, p)
>>> for i in range(m):
...     for j in range(p):
...         C[i, j] = g(A[i], B[:, j])
...
>>> C
tensor([[1.0094, 0.6486, 0.8481, 0.7942, 0.4473],
        [1.3406, 1.1867, 0.9334, 0.7280, 0.8513],
        [0.5945, 0.7856, 0.9520, 1.2410, 0.7428]])
>>>
>>> (A.unsqueeze (-1) - B.unsqueeze (0)).norm (dim = 1)
tensor([[1.0094, 0.6486, 0.8481, 0.7942, 0.4473],
        [1.3406, 1.1867, 0.9334, 0.7280, 0.8513],
        [0.5945, 0.7856, 0.9520, 1.2410, 0.7428]])
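If g cannot easily be split into a broadcastable operation and a reduction, one option is `torch.func.vmap`, which lets you call g exactly as written. This is a sketch assuming PyTorch 2.0 or newer (not the 1.9 used above); `f` is just the question's placeholder name. vmap removes the Python loops, but it presumably still builds a batched intermediate of the same (m, n, p) size, so it should not be expected to save memory compared with broadcasting.

```python
import torch
from torch.func import vmap

def g(a, b):
    # The question's reduction: L2 norm of the residual of two length-n vectors.
    return (a - b).norm()

def f(A, B, g):
    # Inner vmap: map over the columns of B (dim 1), holding the row `a` fixed.
    over_columns = vmap(g, in_dims=(None, 1))
    # Outer vmap: map over the rows of A (dim 0), passing all of B each time.
    return vmap(over_columns, in_dims=(0, None))(A, B)

m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)
C = f(A, B, g)  # shape (m, p); C[i, j] matches g(A[i], B[:, j])
```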

Best.

K. Frank


I think this approach achieves the function I wanted, but when m, n and p are large, its memory complexity is much higher than that of an ordinary matrix product (O(mnp) vs. O(mp)). When this operation is used inside neural networks (e.g. a Transformer), the memory cost could be huge.
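One generic way to trade some speed for bounded memory is to broadcast over a slice of rows of A at a time. The sketch below uses a hypothetical helper `chunked_pairwise_l2` that hard-codes the L2 residual, but the same slicing works for any broadcast-plus-reduce g. Peak extra memory is O(chunk_size * n * p) instead of O(m * n * p), at the cost of a Python loop over m / chunk_size chunks. The default `chunk_size=256` is an arbitrary assumption to tune.

```python
import torch

def chunked_pairwise_l2(A, B, chunk_size=256):
    # A: (m, n), B: (n, p). Only `chunk_size` rows of A are broadcast at once,
    # so the intermediate is at most (chunk_size, n, p) rather than (m, n, p).
    m, p = A.shape[0], B.shape[1]
    C = A.new_empty((m, p))
    for start in range(0, m, chunk_size):
        rows = A[start:start + chunk_size]               # (c, n)
        diff = rows.unsqueeze(-1) - B.unsqueeze(0)       # (c, n, p)
        C[start:start + chunk_size] = diff.norm(dim=1)   # (c, p)
    return C

# Usage: the largest intermediate here is (128, 64, 500), not (1000, 64, 500).
A = torch.rand(1000, 64)
B = torch.rand(64, 500)
C = chunked_pairwise_l2(A, B, chunk_size=128)
```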

Hi Bin!

Yes, I believe that you are correct about this.

When the tensors are broadcast, I do not believe (to the best of my
knowledge) that the broadcast tensors are actually materialized.

But (again to the best of my knowledge) the “outer-difference” tensor
that then becomes the argument to the .norm() most likely is
materialized, so I do expect the memory cost you describe.
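The difference between a broadcast view and a materialized result can be inspected directly. This sketch uses `expand()` to build the broadcast view explicitly: it shares A's storage and has stride 0 along the new axis, so no copy is made. The result of the subtraction, by contrast, is a new tensor with m * n * p elements.

```python
import torch

m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

# expand() gives the broadcast view explicitly: no copy, stride 0 on the new axis.
A_view = A.unsqueeze(-1).expand(m, n, p)
print(A_view.stride())                      # (4, 1, 0)
print(A_view.data_ptr() == A.data_ptr())    # True: same storage as A

# The "outer-difference" tensor, however, is freshly allocated.
diff = A.unsqueeze(-1) - B.unsqueeze(0)
print(diff.shape, diff.numel())             # torch.Size([3, 4, 5]) 60
```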

(If you can afford the memory cost, I do expect the broadcast version
to run much faster than a for-loop implementation.)

As an aside, pytorch’s JIT talks about “fusing kernels,” but it talks
specifically about fusing “pointwise” operations, so I don’t think that
JIT applies to your use case. (I could be wrong.)
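For the specific L2 case, though, the memory problem can be sidestepped by going back to an inner product: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the only pairwise term is A @ B and the extra memory is O(mp). The helper name `pairwise_l2_via_matmul` below is hypothetical. This expansion can lose precision through cancellation when distances are tiny, hence the clamp before the square root. `torch.cdist` computes the same pairwise distances (between the rows of A and the rows of B.T); per its documentation, its default `compute_mode` switches to a matrix-multiplication approach for larger inputs. This trick is specific to L2 and does not carry over to, say, the L1 norm.

```python
import torch

def pairwise_l2_via_matmul(A, B):
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the only pairwise term is A @ B.
    # Extra memory is O(m*p); no (m, n, p) intermediate is created.
    sq_a = A.pow(2).sum(dim=1, keepdim=True)    # (m, 1)
    sq_b = B.pow(2).sum(dim=0, keepdim=True)    # (1, p)
    sq_dist = sq_a + sq_b - 2 * (A @ B)         # (m, p)
    # Rounding can make tiny squared distances slightly negative; clamp before sqrt.
    return sq_dist.clamp(min=0).sqrt()

m, n, p = 3, 4, 5
A = torch.rand(m, n)
B = torch.rand(n, p)

C_mm = pairwise_l2_via_matmul(A, B)
# torch.cdist: pairwise distances between rows of A (m, n) and rows of B.T (p, n).
C_cdist = torch.cdist(A, B.T)
# Broadcast reference for comparison.
C_bcast = (A.unsqueeze(-1) - B.unsqueeze(0)).norm(dim=1)
```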

Best.

K. Frank
