Multiplying large batches of small matrices fast

What’s the best way to multiply many small matrices in PyTorch? For example, say I want to multiply two batches of 2x2 matrices with a large batch size. I seem to get better performance by explicitly writing out the matrix multiplication than by using torch.bmm. Example below:

import torch 

def fast_2x2_bmm(a, b):
    """Multiply batches of 2x2 matrices by writing out every product entry.

    Computes the same result as ``a @ b`` for ``a`` and ``b`` of shape
    ``(..., 2, 2)``; leading batch dimensions broadcast like elementwise ops.

    Args:
        a: tensor of shape (..., 2, 2), the left factors.
        b: tensor of shape (..., 2, 2), the right factors.

    Returns:
        Tensor of shape (broadcast(...), 2, 2) holding the batched products.
    """
    # Entries are stacked in row-major order (c00, c01, c10, c11) so that a
    # plain reshape of the trailing axis yields the product directly, with no
    # transpose and a contiguous result.
    c = torch.stack(
        [
            a[..., 0, 0] * b[..., 0, 0] + a[..., 0, 1] * b[..., 1, 0],
            a[..., 0, 0] * b[..., 0, 1] + a[..., 0, 1] * b[..., 1, 1],
            a[..., 1, 0] * b[..., 0, 0] + a[..., 1, 1] * b[..., 1, 0],
            a[..., 1, 0] * b[..., 0, 1] + a[..., 1, 1] * b[..., 1, 1],
        ],
        dim=-1,
    )
    return c.reshape(*c.shape[:-1], 2, 2)
# Rebind the name to its torch.compile-wrapped version so every later call in
# this file (including the benchmark below) uses the compiled function.
fast_2x2_bmm = torch.compile(fast_2x2_bmm)

def fast_3x3_bmm(a, b):
    """Multiply batches of 3x3 matrices by writing out every product entry.

    Computes the same result as ``a @ b`` for ``a`` and ``b`` of shape
    ``(..., 3, 3)``; leading batch dimensions broadcast like elementwise ops.

    Args:
        a: tensor of shape (..., 3, 3), the left factors.
        b: tensor of shape (..., 3, 3), the right factors.

    Returns:
        Tensor of shape (broadcast(...), 3, 3) holding the batched products.
    """
    # Entries are stacked in row-major order (c00, c01, c02, c10, ...) so a
    # plain reshape of the trailing axis yields the product directly.
    c = torch.stack(
        [
            # row 0
            a[..., 0, 0] * b[..., 0, 0] + a[..., 0, 1] * b[..., 1, 0] + a[..., 0, 2] * b[..., 2, 0],
            a[..., 0, 0] * b[..., 0, 1] + a[..., 0, 1] * b[..., 1, 1] + a[..., 0, 2] * b[..., 2, 1],
            a[..., 0, 0] * b[..., 0, 2] + a[..., 0, 1] * b[..., 1, 2] + a[..., 0, 2] * b[..., 2, 2],
            # row 1
            a[..., 1, 0] * b[..., 0, 0] + a[..., 1, 1] * b[..., 1, 0] + a[..., 1, 2] * b[..., 2, 0],
            a[..., 1, 0] * b[..., 0, 1] + a[..., 1, 1] * b[..., 1, 1] + a[..., 1, 2] * b[..., 2, 1],
            a[..., 1, 0] * b[..., 0, 2] + a[..., 1, 1] * b[..., 1, 2] + a[..., 1, 2] * b[..., 2, 2],
            # row 2
            a[..., 2, 0] * b[..., 0, 0] + a[..., 2, 1] * b[..., 1, 0] + a[..., 2, 2] * b[..., 2, 0],
            a[..., 2, 0] * b[..., 0, 1] + a[..., 2, 1] * b[..., 1, 1] + a[..., 2, 2] * b[..., 2, 1],
            a[..., 2, 0] * b[..., 0, 2] + a[..., 2, 1] * b[..., 1, 2] + a[..., 2, 2] * b[..., 2, 2],
        ],
        dim=-1,
    )
    return c.reshape(*c.shape[:-1], 3, 3)
# Rebind the name to its torch.compile-wrapped version so every later call in
# this file (including the benchmark below) uses the compiled function.
fast_3x3_bmm = torch.compile(fast_3x3_bmm)

def torch_bmm(a, b):
    """Reference batched matmul: a thin wrapper around ``torch.bmm``.

    Args:
        a: 3-D tensor of shape (B, n, k).
        b: 3-D tensor of shape (B, k, m).

    Returns:
        Tensor of shape (B, n, m) with ``out[i] = a[i] @ b[i]``.
    """
    product = a.bmm(b)
    return product
# Rebind the name to its torch.compile-wrapped version so the benchmark below
# profiles a compiled torch.bmm alongside the eager one.
torch_bmm = torch.compile(torch_bmm)

if __name__ == "__main__":
    # Benchmark: profile eager torch.bmm, compiled torch.bmm and the compiled
    # written-out kernel on a large batch of small DxD matrices on the GPU.
    from torch.profiler import profile, record_function, ProfilerActivity

    B = 262144  # batch size
    D = 2  # matrix side length; only 2 and 3 have written-out kernels
    if D not in (2, 3):
        raise NotImplementedError(f"no written-out kernel for D={D}")
    fast_bmm = fast_2x2_bmm if D == 2 else fast_3x3_bmm

    def rand_batch():
        # Allocate directly on the GPU rather than on the CPU followed by a copy.
        return torch.randn(B, D, D, device="cuda")

    # Run every function once outside the profiler so one-time costs
    # (torch.compile compiles on first call; library setup) are not measured.
    torch.bmm(rand_batch(), rand_batch())
    torch_bmm(rand_batch(), rand_batch())
    fast_bmm(rand_batch(), rand_batch())

    mat0 = rand_batch()
    mat1 = rand_batch()
    # assert_close raises an informative error (a bare assert is stripped
    # under -O) and uses dtype-aware tolerances, so rounding differences from
    # a different summation order between kernels do not trip the check.
    torch.testing.assert_close(fast_bmm(mat0, mat1), torch_bmm(mat0, mat1))
    # Make sure no warm-up/check work is still in flight when profiling starts.
    torch.cuda.synchronize()

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    ) as prof:
        with record_function("bmm"):
            torch.bmm(mat0, mat1)
        with record_function("compiled bmm"):
            torch_bmm(mat0, mat1)
        with record_function("compiled fast bmm"):
            fast_bmm(mat0, mat1)
        # Wait for the queued kernels before the profiler stops collecting.
        torch.cuda.synchronize()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

I get 886 µs of CUDA time for torch.bmm, but only 22 µs when explicitly writing out the multiplication.

Horace He made the following suggestion on Twitter:

def fast_bmm(a, b):
    """Batched matmul via broadcasting, equivalent to ``a @ b``.

    ``a`` has shape (..., n, k) and ``b`` has shape (..., k, m).
    ``a.unsqueeze(-1)`` is (..., n, k, 1) and ``b.unsqueeze(-3)`` is
    (..., 1, k, m). Their elementwise product is (..., n, k, m). Summing over
    the shared k axis (dim -2) gives the (..., n, m) result.

    The full n*k*m intermediate is materialized, so this is only sensible for
    small matrices.
    """
    return (a.unsqueeze(-1) * b.unsqueeze(-3)).sum(-2)

which seems to be as fast as the “written out” solution above when using torch.compile, and faster when torch.compile is not used. Profiling code here: