Does torch.compile try to reorder matmuls for efficiency?

I’m just curious whether torch.compile is able to perform opt_einsum-style optimizations, where the order of matrix multiplications is chosen to reduce compute.

The minimal example here is

@torch.compile
def matmul(A, B, C):
    return A @ B @ C

In the case where, say, A is 1000 x 100, B is 100 x 10, and C is 10 x 1, it is clearly more efficient to compute A @ (B @ C), multiplying the last two matrices first thanks to associativity (for these shapes, roughly 10x fewer multiply-adds; see the quick count below). I’m wondering if this is something torch.compile can optimize for.
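For concreteness, here is a quick back-of-the-envelope count of the multiply-adds for the two orderings (this uses only the shapes above, nothing torch-specific):

# an (m x k) @ (k x n) matmul costs m * k * n multiply-adds
m, k1, k2, n = 1000, 100, 10, 1

left_first  = m * k1 * k2 + m * k2 * n    # (A @ B) @ C  ->  1,010,000
right_first = k1 * k2 * n + m * k1 * n    # A @ (B @ C)  ->    101,000
print(left_first, right_first)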

Hi Pea!

This does not appear to be something that torch.compile optimizes.

It’s speculation on my part, but I don’t think that torch.compile knows
that it is doing matrix multiplication (nor realizes that matrix multiplication
is associative).

Note that torch.linalg.multi_dot() has this matrix-associativity optimization
as its specific purpose.

Here is a script that shows no torch.compile speed-up for your minimal
example, while showing speed-up when explicitly performing B @ C first
and when letting multi_dot() perform the optimization:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())
print (torch.cuda.get_device_capability())

import time

def matmul_no_compile (A, B, C):
    return A @ B @ C                 # equivalent to (A @ B) @ C

@torch.compile (backend = 'eager')   # 'eager' for older gpu
def matmul_compile (A, B, C):
    return A @ B @ C                 # equivalent to (A @ B) @ C

def matmul_rl (A, B, C):
    return A @ (B @ C)               # explicit efficient ordering

Aa = []
Bb = []
Cc = []

print ('generating data ...')
for  i in range (10):
    Aa.append (torch.randn (10000, 1000, device = 'cuda'))
    Bb.append (torch.randn (1000, 100, device = 'cuda'))
    Cc.append (torch.randn (100, 10, device = 'cuda'))

print ('run timing loops ...')

# matmul_no_compile

# run warmup
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = matmul_no_compile (A, B, C)

# timing loop
torch.cuda.synchronize()
t0 = time.time()
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = matmul_no_compile (A, B, C)

torch.cuda.synchronize()
t1 = time.time()
print ('matmul_no_compile:  elapsed time: ', t1 - t0)

# matmul_compile

# run warmup
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = matmul_compile (A, B, C)

# timing loop
torch.cuda.synchronize()
t0 = time.time()
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = matmul_compile (A, B, C)

torch.cuda.synchronize()
t1 = time.time()
print ('matmul_compile:     elapsed time: ', t1 - t0)

# matmul_rl

# run warmup
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = matmul_rl (A, B, C)

# timing loop
torch.cuda.synchronize()
t0 = time.time()
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = matmul_rl (A, B, C)

torch.cuda.synchronize()
t1 = time.time()
print ('matmul_rl:          elapsed time: ', t1 - t0)

# multi_dot -- knows how to reorder matrix multiplications

# run warmup
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = torch.linalg.multi_dot ((A, B, C))

# timing loop
torch.cuda.synchronize()
t0 = time.time()
for  i in range (10):
    for  A, B, C in zip (Aa, Bb, Cc):
        D = torch.linalg.multi_dot ((A, B, C))

torch.cuda.synchronize()
t1 = time.time()
print ('multi_dot:          elapsed time: ', t1 - t0)

And here is its output:

2.1.0
11.8
GeForce GTX 1050 Ti
(6, 1)
generating data ...
run timing loops ...
matmul_no_compile:  elapsed time:  0.18052339553833008
matmul_compile:     elapsed time:  0.1767871379852295
matmul_rl:          elapsed time:  0.056723833084106445
multi_dot:          elapsed time:  0.05632662773132324

Best.

K. Frank


Thank you Frank!

I’ve adapted your code to run on larger matrices, with TensorFloat-32 enabled, and with torch.einsum added to the mix. It turns out you are completely right: torch.compile appears to make no effort to reorder matmuls for efficiency.

import torch

print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())
print (torch.cuda.get_device_capability())

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import time


def matmul_no_compile(A, B, C):
    return A @ B @ C                 # equivalent to (A @ B) @ C

@torch.compile
def matmul_compile(A, B, C):
    return A @ B @ C                 # equivalent to (A @ B) @ C

def matmul_rl(A, B, C):
    return A @ (B @ C)               # explicit efficient ordering

def multi_dot(A, B, C):
    return torch.linalg.multi_dot((A, B, C))

def einsum(A, B, C):
    return torch.einsum('ij,jk,kl->il', A, B, C)

Aa, Bb, Cc = [], [], []

print ('generating data ...')
for  i in range (10):
    Aa.append (torch.randn (2 ** 16, 2 ** 12, device = 'cuda'))
    Bb.append (torch.randn (2 ** 12, 2 ** 12, device = 'cuda'))
    Cc.append (torch.randn (2 ** 12, 2 ** 8, device = 'cuda'))
    
print(f"matrices shapes: {Aa[0].shape} {Bb[0].shape} {Cc[0].shape}")

print ('run timing loops ...')

# matmul_no_compile

def run_benchmark(fn, Aa, Bb, Cc):
    # warmup
    for i in range(10):
        for (A, B, C) in zip(Aa, Bb, Cc):
            D = fn(A, B, C)
    
    # timing
    torch.cuda.synchronize()
    t0 = time.time()
    for  i in range (10):
        for  A, B, C in zip (Aa, Bb, Cc):
            D = fn(A, B, C)
            
    torch.cuda.synchronize()
    t1 = time.time()
    print(f'{fn.__name__}: {t1 - t0}')


run_benchmark(matmul_no_compile, Aa, Bb, Cc)
run_benchmark(matmul_compile, Aa, Bb, Cc)
run_benchmark(matmul_rl, Aa, Bb, Cc)
run_benchmark(multi_dot, Aa, Bb, Cc)
run_benchmark(einsum, Aa, Bb, Cc)

And the output:

2.1.0
11.8
NVIDIA A30
(8, 0)
generating data ...
matrices shapes: torch.Size([65536, 4096]) torch.Size([4096, 4096]) torch.Size([4096, 256])
run timing loops ...
matmul_no_compile: 4.105190992355347
matmul_compile: 4.124130725860596
matmul_rl: 0.31899595260620117
multi_dot: 0.31679391860961914
einsum: 0.32097911834716797

I guess the question now is: should torch.compile try to support this? I’m not familiar enough with compilation to know how difficult it would be, but I’m guessing that a chain of matmuls is preserved down to the IR (?), where it should be possible to reorder the multiplications (at least for simple patterns such as a linear chain).
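For what it’s worth, the chain does seem to be visible at the graph level. Here is a rough illustration using torch.fx.symbolic_trace (not exactly the graph that torch.compile/Dynamo captures, but similar in spirit): the two matmuls show up as separate nodes in left-to-right order, so a graph pass could in principle see and reorder them.

import torch
import torch.fx

def chain(A, B, C):
    return A @ B @ C

gm = torch.fx.symbolic_trace(chain)
print(gm.graph)
# the printed graph contains two call_function nodes targeting operator.matmul,
# the second consuming the output of the first, i.e. (A @ B) @ C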

Hi Pea!

I’m glad that you confirmed this on a more modern gpu (and hence a
different compilation backend).

That wouldn’t seem unreasonable to me. I don’t think that proposing
this as a feature request on GitHub would be poorly received.

I don’t know anything about the internals of torch.compile, but this all
sounds plausible to me. The logic is straightforward enough (witness
multi_dot()). I just don’t know whether torch.compile has hooks in
the right places to make it easy.
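As a sketch of the basic idea (written just for illustration here, not
taken from pytorch’s source), this is the textbook matrix-chain-ordering
dynamic program, which is the kind of logic a multi_dot()-style
optimization relies on:

def chain_order (dims):
    # matrix i has shape (dims[i], dims[i + 1]), so n matrices need n + 1 dims
    n = len (dims) - 1
    cost  = [[0] * n for _ in range (n)]   # cost[i][j]: cheapest multiply-add count for matrices i..j
    split = [[0] * n for _ in range (n)]   # split[i][j]: best place to split the product i..j
    for length in range (2, n + 1):
        for i in range (n - length + 1):
            j = i + length - 1
            cost[i][j] = float ('inf')
            for k in range (i, j):
                c = cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if  c < cost[i][j]:
                    cost[i][j], split[i][j] = c, k
    return cost, split

# shapes from your original example: A is 1000 x 100, B is 100 x 10, C is 10 x 1
cost, split = chain_order ([1000, 100, 10, 1])
print (cost[0][2])    # 101000 multiply-adds, versus 1010000 for left-to-right
print (split[0][2])   # 0, i.e. split after A:  A @ (B @ C)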

I don’t think, however, that chains of matrix multiplications (without
intervening non-linearities or additive biases) are that common a
use case in pytorch, so implementing something like this, even if
relatively easy, might not be a priority.

(If you do file a feature request, I would suggest including a couple of
real-world use cases to motivate the feature.)

Best.

K. Frank