# Matrix Multiplication as Reduction

I am interested in matrix-multiplying many matrices stored in a single tensor. By analogy, let me call this “reducing via matrix multiplication”. As an example,

``````
#!/usr/bin/env python

import torch
torch.manual_seed(1234)

t = 4
x = 2
v = torch.rand(t, x, x)

print(f'''Here is a random tensor of shape {t=} by {x=} by {x=}
''')
print(v)

print('''

Here we reduce v by summing along axis 0
''')
summed = v.sum(axis=0)
print(summed)

print('''

But notice!  Each slice of axis 0 is an x by x matrix.
So there's a lot more we can do.
In particular, it makes sense to matrix multiply different slices of v together!

We could multiply them in any order, in principle.
But clear thinking and well-organized code presumably already has them in the right order.
''')

reduced = v[0]
for i in range(1, t):
    # left-multiply, so reduced = v[t-1] @ ... @ v[1] @ v[0]
    reduced = torch.matmul(v[i], reduced)

print('''The reduced matrix is

''')
print(reduced)

print('''

We might want the equivalent of a 'running sum'.
That is, all the results of the intermediate multiplies.
''')
folded = v.clone()  # clone so the accumulation does not overwrite v

for i in range(1, t):
    folded[i] = torch.matmul(v[i], folded[i-1])

print(folded)
``````

which yields

``````
> ./mm_reduction.py
Here is a random tensor of shape t=4 by x=2 by x=2

tensor([[[0.0290, 0.4019],
         [0.2598, 0.3666]],

        [[0.0583, 0.7006],
         [0.0518, 0.4681]],

        [[0.6738, 0.3315],
         [0.7837, 0.5631]],

        [[0.7749, 0.8208],
         [0.2793, 0.6817]]])

Here we reduce v by summing along axis 0

tensor([[1.5359, 2.2548],
        [1.3746, 2.0796]])

But notice!  Each slice of axis 0 is an x by x matrix.
So there's a lot more we can do.
In particular, it makes sense to matrix multiply different slices of v together!

We could multiply them in any order, in principle.
But clear thinking and well-organized code presumably already has them in the right order.

The reduced matrix is

tensor([[0.3027, 0.4650],
        [0.1914, 0.2942]])

We might want the equivalent of a 'running sum'.
That is, all the results of the intermediate multiplies.

tensor([[[0.0290, 0.4019],
         [0.2598, 0.3666]],

        [[0.1837, 0.2803],
         [0.1231, 0.1925]],

        [[0.1646, 0.2527],
         [0.2133, 0.3281]],

        [[0.3027, 0.4650],
         [0.1914, 0.2942]]])
``````

This tiny example runs quickly, of course, because everything is so small. But when t is large and x is of intermediate size, the Python loop launches one small CUDA kernel per multiply, which means very many kernel launches on a GPU. Is there a sensible call that accomplishes this ‘matrix multiplication reduction’? What about the reduction that yields the intermediates?

I wrote my example in PyTorch, but I’m most interested in solutions in C++ libtorch.

If not, how can I craft a custom CUDA kernel accomplishing this reduction and coax libtorch to use it?
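The closest workaround I can sketch in plain PyTorch (and, I assume, with the same structure in libtorch via `torch::bmm` / `torch::cat`) is a pairwise tree reduction for the full product and a doubling-style inclusive scan for the intermediates. The helper names below are made up for illustration; both sketches do more total arithmetic than the sequential loop, but each needs only about log2(t) batched launches.

``````
import torch

def matmul_reduce(v: torch.Tensor) -> torch.Tensor:
    """Return v[t-1] @ ... @ v[1] @ v[0] for a (t, x, x) tensor.

    Pairwise (tree) reduction: each round halves the batch with one
    torch.bmm call, so roughly ceil(log2(t)) launches instead of t-1.
    """
    while v.shape[0] > 1:
        if v.shape[0] % 2:            # odd batch: set the rightmost factor aside
            head, v = v[:1], v[1:]
        else:
            head = None
        # one batched call multiplies all neighbouring pairs: v[1]@v[0], v[3]@v[2], ...
        v = torch.bmm(v[1::2], v[0::2])
        if head is not None:
            v = torch.cat([head, v])  # keep the held-out matrix as the rightmost factor
    return v[0]

def matmul_scan(v: torch.Tensor) -> torch.Tensor:
    """Inclusive 'running product': out[i] = v[i] @ v[i-1] @ ... @ v[0].

    Hillis-Steele doubling scan: about log2(t) torch.bmm launches, at the
    cost of O(t log t) total multiplies instead of O(t).
    """
    out = v.clone()
    stride = 1
    while stride < out.shape[0]:
        # combine each element with the partial product `stride` slots earlier
        out = torch.cat([out[:stride],
                         torch.bmm(out[stride:], out[:-stride])], dim=0)
        stride *= 2
    return out

torch.manual_seed(1234)
v = torch.rand(4, 2, 2)
print(matmul_reduce(v))   # should match `reduced` above
print(matmul_scan(v))     # should match `folded` above
``````

For t=4 this buys nothing, but for large t the launch count drops from t-1 sequential multiplies to a handful of bmm calls.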

Here is another similar reduction I’m interested in for my application:

``````
print(f'''

We could also go "all the way around",
meaning we'd get {t=} matrix products.
''')

u = v.clone()
w = v.clone()
for i in range(1, t):
    u = u.roll(-1, dims=(0,))
    w = torch.bmm(u, w)

print(w)
``````

yielding

``````
We could also go "all the way around",
meaning we'd get t=4 matrix products.

tensor([[[0.3027, 0.4650],
         [0.1914, 0.2942]],

        [[0.0299, 0.3265],
         [0.0518, 0.5670]],

        [[0.4166, 0.2657],
         [0.2825, 0.1802]],

        [[0.1981, 0.3074],
         [0.2569, 0.3987]]])
``````
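Incidentally, this “all the way around” set can be assembled from the running product above together with the analogous suffix products, so whatever speeds up the scan covers this case too. A sketch (the sequential loops are just slow reference versions of the two scans, and `prefix`, `suffix`, and `cyclic` are names I made up):

``````
import torch

torch.manual_seed(1234)
t, x = 4, 2
v = torch.rand(t, x, x)

# prefix[j] = v[j] @ ... @ v[0]        (the 'running product' from above)
prefix = v.clone()
for i in range(1, t):
    prefix[i] = prefix[i] @ prefix[i - 1]

# suffix[j] = v[t-1] @ ... @ v[j]
suffix = v.clone()
for i in range(t - 2, -1, -1):
    suffix[i] = suffix[i + 1] @ suffix[i]

# cyclic[j] = v[j-1] @ ... @ v[0] @ v[t-1] @ ... @ v[j]   ("all the way around")
cyclic = torch.empty_like(v)
cyclic[0] = prefix[-1]                           # the full product
cyclic[1:] = torch.bmm(prefix[:-1], suffix[1:])  # one batched call for the rest

print(cyclic)  # should reproduce w above
``````

So beyond the two scans, only one extra batched bmm is needed.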

In case you haven’t already, have you investigated whether it is possible to express your reductions/multiplications via `torch.einsum`? (See torch.einsum — PyTorch 2.0 documentation.)

Yes. This operation is not einsummable in any straightforward way.
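For a fixed, hard-coded number of factors the chain can of course be spelled out operand by operand, for example:

``````
import torch

torch.manual_seed(1234)
v = torch.rand(4, 2, 2)

# A hard-coded 4-factor chain: equals v[3] @ v[2] @ v[1] @ v[0]
chained = torch.einsum('ab,bc,cd,de->ae', v[3], v[2], v[1], v[0])
print(chained)  # should match `reduced` above
``````

But that requires one operand and one subscript pair per factor, so it is not a reduction along axis 0 and does not extend to large or runtime-determined t.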