I am interested in matrix-multiplying many matrices stored in a single tensor. By analogy with reducing a tensor by summing along an axis, let me call this “reducing via matrix multiplication”. As an example, consider this script:
```python
#!/usr/bin/env python
import torch
torch.manual_seed(1234)
t = 4
x = 2
v = torch.rand(t, x, x)
print(f'''Here is a random tensor of shape {t=} by {x=} by {x=}
''')
print(v)
print('''
Here we reduce v by summing along axis 0
''')
summed = v.sum(axis=0)
print(summed)
print('''
But notice! Each slice of axis 0 is an x by x matrix.
So there's a lot more we can do.
In particular, it makes sense to matrix multiply different slices of v together!
We could multiply them in any order, in principle.
But clear thinking and well-organized code presumably already has them in the right order.
''')
# Multiply the slices from the left: reduced = v[t-1] @ ... @ v[1] @ v[0]
reduced = v[0]
for i in range(1, t):
    reduced = torch.matmul(v[i], reduced)
print('''The reduced matrix is
''')
print(reduced)
print('''
We might want the equivalent of a 'running sum'.
That is, all the results of the intermediate multiplies.
''')
# Running product: folded[i] = v[i] @ ... @ v[1] @ v[0]
folded = torch.zeros_like(v)
folded[0] = v[0]
for i in range(1, t):
    folded[i] = torch.matmul(v[i], folded[i-1])
print(folded)
```
which yields:
```
> ./mm_reduction.py
Here is a random tensor of shape t=4 by x=2 by x=2

tensor([[[0.0290, 0.4019],
         [0.2598, 0.3666]],

        [[0.0583, 0.7006],
         [0.0518, 0.4681]],

        [[0.6738, 0.3315],
         [0.7837, 0.5631]],

        [[0.7749, 0.8208],
         [0.2793, 0.6817]]])

Here we reduce v by summing along axis 0

tensor([[1.5359, 2.2548],
        [1.3746, 2.0796]])

But notice! Each slice of axis 0 is an x by x matrix.
So there's a lot more we can do.
In particular, it makes sense to matrix multiply different slices of v together!
We could multiply them in any order, in principle.
But clear thinking and well-organized code presumably already has them in the right order.

The reduced matrix is

tensor([[0.3027, 0.4650],
        [0.1914, 0.2942]])

We might want the equivalent of a 'running sum'.
That is, all the results of the intermediate multiplies.

tensor([[[0.0290, 0.4019],
         [0.2598, 0.3666]],

        [[0.1837, 0.2803],
         [0.1231, 0.1925]],

        [[0.1646, 0.2527],
         [0.2133, 0.3281]],

        [[0.3027, 0.4650],
         [0.1914, 0.2942]]])
```
This tiny example is quick, of course, because everything is so small. But when t is large and x is of intermediate size, each loop makes t − 1 sequential `torch.matmul` calls, so on a GPU it launches very many small CUDA kernels one after another. Is there a sensible call to accomplish this ‘matrix multiplication reduction’? What about the variant that also returns all the intermediate products?
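One possible approach is sketched below in Python. It assumes all slices are square matrices of the same size and that the product order is `v[t-1] @ ... @ v[0]`, as in the script. Because matrix multiplication is associative, every adjacent pair can be multiplied in a single batched `torch.matmul` call, which halves t each round. The full reduction then takes about log2(t) batched calls instead of t − 1 single ones. The running products can be computed the same way with a Hillis–Steele-style prefix scan. The names `tree_matmul_reduce` and `scan_matmul` are hypothetical helpers, not PyTorch API.

```python
import torch


def tree_matmul_reduce(v: torch.Tensor) -> torch.Tensor:
    """Return v[t-1] @ ... @ v[1] @ v[0] using about log2(t) batched matmuls."""
    while v.shape[0] > 1:
        t = v.shape[0]
        even = t - (t % 2)  # number of slices that can be paired up
        # One batched call multiplies every pair: out[k] = v[2k+1] @ v[2k]
        paired = torch.matmul(v[1:even:2], v[0:even:2])
        if t % 2:
            # An odd leftover v[t-1] is the leftmost factor, so keep it last
            paired = torch.cat([paired, v[t - 1:]], dim=0)
        v = paired
    return v[0]


def scan_matmul(v: torch.Tensor) -> torch.Tensor:
    """Return s with s[i] = v[i] @ ... @ v[0] (Hillis-Steele-style scan)."""
    s = v
    d = 1
    while d < s.shape[0]:
        # Invariant: s[i] holds v[i] @ ... @ v[max(0, i-d+1)].
        # Multiplying by s[i-d] on the right extends it to v[max(0, i-2d+1)];
        # all i >= d are updated in one batched call using the old values.
        s = torch.cat([s[:d], torch.matmul(s[d:], s[:-d])], dim=0)
        d *= 2
    return s
```

With the script's `v`, `tree_matmul_reduce(v)` should match `reduced` and `scan_matmul(v)` should match `folded` up to floating-point rounding, since the products are grouped differently. The scan performs O(t log t) matrix products in total rather than t − 1, trading extra arithmetic for fewer, larger launches. A work-efficient (Blelloch) scan would bring the work back to O(t) at the cost of roughly twice as many rounds. The same operations exist in libtorch (`torch::matmul`, `torch::cat`, `Tensor::slice`), so the sketch should translate to C++ directly.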
I wrote my example in PyTorch, but I’m most interested in solutions in C++ libtorch.
If no such call exists, how can I write a custom CUDA kernel for this reduction and get libtorch to use it?