# Matrix Multiplication as Reduction

I am interested in matrix-multiplying many matrices stored in a single tensor. By analogy, let me call this “reducing via matrix multiplication”. As an example,

``````
#!/usr/bin/env python

import torch
torch.manual_seed(1234)

t = 4
x = 2
v = torch.rand(t, x, x)

print(f'''Here is a random tensor of shape {t=} by {x=} by {x=}
''')
print(v)

print('''

Here we reduce v by summing along axis 0
''')
summed = v.sum(axis=0)
print(summed)

print('''

But notice!  Each slice of axis 0 is an x by x matrix.
So there's a lot more we can do.
In particular, it makes sense to matrix multiply different slices of v together!

We could multiply them in any order, in principle.
But clear thinking and well-organized code presumably already has them in the right order.
''')

reduced = v[0]
for i in range(1, t):
    # left-multiply, so reduced = v[t-1] @ ... @ v[1] @ v[0]
    reduced = torch.matmul(v[i], reduced)

print('''The reduced matrix is

''')
print(reduced)

print('''

We might want the equivalent of a 'running sum'.
That is, all the results of the intermediate multiplies.
''')
folded = v.clone()  # clone so the accumulation does not overwrite v

for i in range(1, t):
    folded[i] = torch.matmul(v[i], folded[i-1])

print(folded)
``````

which yields

``````
> ./mm_reduction.py
Here is a random tensor of shape t=4 by x=2 by x=2

tensor([[[0.0290, 0.4019],
         [0.2598, 0.3666]],

        [[0.0583, 0.7006],
         [0.0518, 0.4681]],

        [[0.6738, 0.3315],
         [0.7837, 0.5631]],

        [[0.7749, 0.8208],
         [0.2793, 0.6817]]])

Here we reduce v by summing along axis 0

tensor([[1.5359, 2.2548],
        [1.3746, 2.0796]])

But notice!  Each slice of axis 0 is an x by x matrix.
So there's a lot more we can do.
In particular, it makes sense to matrix multiply different slices of v together!

We could multiply them in any order, in principle.
But clear thinking and well-organized code presumably already has them in the right order.

The reduced matrix is

tensor([[0.3027, 0.4650],
        [0.1914, 0.2942]])

We might want the equivalent of a 'running sum'.
That is, all the results of the intermediate multiplies.

tensor([[[0.0290, 0.4019],
         [0.2598, 0.3666]],

        [[0.1837, 0.2803],
         [0.1231, 0.1925]],

        [[0.1646, 0.2527],
         [0.2133, 0.3281]],

        [[0.3027, 0.4650],
         [0.1914, 0.2942]]])
``````

This tiny example runs quickly, of course, because everything is so small. But when t is large and x is of intermediate size, the Python loop launches one small CUDA kernel per multiply, which means very many kernel launches on a GPU. Is there a sensible call that accomplishes this ‘matrix multiplication reduction’? What about the reduction that yields the intermediates?

I wrote my example in PyTorch, but I’m most interested in solutions in C++ libtorch.

If not, how can I craft a custom CUDA kernel accomplishing this reduction and coax libtorch to use it?
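The closest workaround I can sketch in plain PyTorch (and, I assume, with the same structure in libtorch via `torch::bmm` / `torch::cat`) is a pairwise tree reduction for the full product and a doubling-style inclusive scan for the intermediates. The helper names below are made up for illustration; both sketches do more total arithmetic than the sequential loop, but each needs only about log2(t) batched launches.

``````
import torch

def matmul_reduce(v: torch.Tensor) -> torch.Tensor:
    """Return v[t-1] @ ... @ v[1] @ v[0] for a (t, x, x) tensor.

    Pairwise (tree) reduction: each round halves the batch with one
    torch.bmm call, so roughly ceil(log2(t)) launches instead of t-1.
    """
    while v.shape[0] > 1:
        if v.shape[0] % 2:            # odd batch: set the rightmost factor aside
            head, v = v[:1], v[1:]
        else:
            head = None
        # one batched call multiplies all neighbouring pairs: v[1]@v[0], v[3]@v[2], ...
        v = torch.bmm(v[1::2], v[0::2])
        if head is not None:
            v = torch.cat([head, v])  # keep the held-out matrix as the rightmost factor
    return v[0]

def matmul_scan(v: torch.Tensor) -> torch.Tensor:
    """Inclusive 'running product': out[i] = v[i] @ v[i-1] @ ... @ v[0].

    Hillis-Steele doubling scan: about log2(t) torch.bmm launches, at the
    cost of O(t log t) total multiplies instead of O(t).
    """
    out = v.clone()
    stride = 1
    while stride < out.shape[0]:
        # combine each element with the partial product `stride` slots earlier
        out = torch.cat([out[:stride],
                         torch.bmm(out[stride:], out[:-stride])], dim=0)
        stride *= 2
    return out

torch.manual_seed(1234)
v = torch.rand(4, 2, 2)
print(matmul_reduce(v))   # should match `reduced` above
print(matmul_scan(v))     # should match `folded` above
``````

For t=4 this buys nothing, but for large t the launch count drops from t-1 sequential multiplies to a handful of bmm calls.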

Here is another similar reduction I’m interested in for my application:

``````
print(f'''

We could also go "all the way around",
meaning we'd get {t=} matrix products.
''')

u = v.clone()
w = v.clone()
for i in range(1, t):
    u = u.roll(-1, dims=(0,))
    w = torch.bmm(u, w)

print(w)
``````

yielding

``````
We could also go "all the way around",
meaning we'd get t=4 matrix products.

tensor([[[0.3027, 0.4650],
         [0.1914, 0.2942]],

        [[0.0299, 0.3265],
         [0.0518, 0.5670]],

        [[0.4166, 0.2657],
         [0.2825, 0.1802]],

        [[0.1981, 0.3074],
         [0.2569, 0.3987]]])
``````
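Incidentally, this “all the way around” set can be assembled from the running product above together with the analogous suffix products, so whatever speeds up the scan covers this case too. A sketch (the sequential loops are just slow reference versions of the two scans, and `prefix`, `suffix`, and `cyclic` are names I made up):

``````
import torch

torch.manual_seed(1234)
t, x = 4, 2
v = torch.rand(t, x, x)

# prefix[j] = v[j] @ ... @ v[0]        (the 'running product' from above)
prefix = v.clone()
for i in range(1, t):
    prefix[i] = prefix[i] @ prefix[i - 1]

# suffix[j] = v[t-1] @ ... @ v[j]
suffix = v.clone()
for i in range(t - 2, -1, -1):
    suffix[i] = suffix[i + 1] @ suffix[i]

# cyclic[j] = v[j-1] @ ... @ v[0] @ v[t-1] @ ... @ v[j]   ("all the way around")
cyclic = torch.empty_like(v)
cyclic[0] = prefix[-1]                           # the full product
cyclic[1:] = torch.bmm(prefix[:-1], suffix[1:])  # one batched call for the rest

print(cyclic)  # should reproduce w above
``````

So beyond the two scans, only one extra batched bmm is needed.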

In case you haven’t already, have you investigated whether it is possible to express your reductions/multiplications via `torch.einsum`? (See torch.einsum — PyTorch 2.0 documentation.)

Yes. This operation is not einsummable in any straightforward way.
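For a fixed, hard-coded number of factors the chain can of course be spelled out operand by operand, for example:

``````
import torch

torch.manual_seed(1234)
v = torch.rand(4, 2, 2)

# A hard-coded 4-factor chain: equals v[3] @ v[2] @ v[1] @ v[0]
chained = torch.einsum('ab,bc,cd,de->ae', v[3], v[2], v[1], v[0])
print(chained)  # should match `reduced` above
``````

But that requires one operand and one subscript pair per factor, so it is not a reduction along axis 0 and does not extend to large or runtime-determined t.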