I’m just curious whether torch.compile is able to perform opt_einsum-style optimizations, where the order of matrix multiplications is chosen to reduce compute.
The minimal example here is:
@torch.compile
def matmul(A, B, C):
    return A @ B @ C
In the case where, say, A is 1000 x 100, B is 100 x 10, and C is 10 x 1, it is clearly more efficient to perform the matmul as A @ (B @ C), where the last two matrices can be multiplied first due to associativity. I’m wondering if this is something torch.compile can optimize for.
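To make the cost difference concrete for those shapes, here is a rough back-of-the-envelope count of multiply-adds (multiplying an m x k matrix by a k x n matrix costs roughly m*k*n multiply-adds):
# rough cost in multiply-adds for each ordering (shapes as above)
cost_left  = 1000 * 100 * 10 + 1000 * 10 * 1   # (A @ B) @ C  -> 1,010,000
cost_right = 100 * 10 * 1 + 1000 * 100 * 1     # A @ (B @ C)  ->   101,000
print(cost_left, cost_right)                   # about a 10x difference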
This does not appear to be something that torch.compile optimizes.
It’s speculation on my part, but I don’t think that torch.compile knows
that it is doing matrix multiplication (nor realizes that matrix multiplication
is associative).
Note that torch.linalg.multi_dot() has this matrix-associativity optimization
as its specific purpose.
Here is a script that shows no torch.compile speed-up for your minimal
example, while showing speed-up when explicitly performing B @ C first
and when letting multi_dot() perform the optimization:
import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())
print (torch.cuda.get_device_capability())

import time

def matmul_no_compile (A, B, C):
    return A @ B @ C               # equivalent to (A @ B) @ C

@torch.compile (backend = 'eager') # 'eager' for older gpu
def matmul_compile (A, B, C):
    return A @ B @ C               # equivalent to (A @ B) @ C

def matmul_rl (A, B, C):
    return A @ (B @ C)             # explicit efficient ordering

Aa = []
Bb = []
Cc = []

print ('generating data ...')
for i in range (10):
    Aa.append (torch.randn (10000, 1000, device = 'cuda'))
    Bb.append (torch.randn (1000, 100, device = 'cuda'))
    Cc.append (torch.randn (100, 10, device = 'cuda'))

print ('run timing loops ...')

# matmul_no_compile
# run warmup
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = matmul_no_compile (A, B, C)
# timing loop
torch.cuda.synchronize()
t0 = time.time()
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = matmul_no_compile (A, B, C)
torch.cuda.synchronize()
t1 = time.time()
print ('matmul_no_compile: elapsed time: ', t1 - t0)

# matmul_compile
# run warmup
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = matmul_compile (A, B, C)
# timing loop
torch.cuda.synchronize()
t0 = time.time()
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = matmul_compile (A, B, C)
torch.cuda.synchronize()
t1 = time.time()
print ('matmul_compile: elapsed time: ', t1 - t0)

# matmul_rl
# run warmup
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = matmul_rl (A, B, C)
# timing loop
torch.cuda.synchronize()
t0 = time.time()
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = matmul_rl (A, B, C)
torch.cuda.synchronize()
t1 = time.time()
print ('matmul_rl: elapsed time: ', t1 - t0)

# multi_dot -- knows how to reorder matrix multiplications
# run warmup
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = torch.linalg.multi_dot ((A, B, C))
# timing loop
torch.cuda.synchronize()
t0 = time.time()
for i in range (10):
    for A, B, C in zip (Aa, Bb, Cc):
        D = torch.linalg.multi_dot ((A, B, C))
torch.cuda.synchronize()
t1 = time.time()
print ('multi_dot: elapsed time: ', t1 - t0)
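As an optional sanity check (not timed), you can also confirm that the reordered product agrees with the plain left-to-right product up to floating-point rounding, along these lines:
# sanity check (not timed): reordering should only change floating-point rounding
A, B, C = Aa[0], Bb[0], Cc[0]
ref = matmul_no_compile (A, B, C)
alt = matmul_rl (A, B, C)
print ('max relative difference: ', ((alt - ref).abs().max() / ref.abs().max()).item())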
I’ve adapted your code to run on larger matrices, with TensorFloat-32 enabled and with torch.einsum added to the mix. It turns out you are completely right: torch.compile seems to make no effort to reorder matmuls for efficiency.
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.get_device_name())
print(torch.cuda.get_device_capability())

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import time

def matmul_no_compile(A, B, C):
    return A @ B @ C        # equivalent to (A @ B) @ C

@torch.compile
def matmul_compile(A, B, C):
    return A @ B @ C        # equivalent to (A @ B) @ C

def matmul_rl(A, B, C):
    return A @ (B @ C)      # explicit efficient ordering

def multi_dot(A, B, C):
    return torch.linalg.multi_dot((A, B, C))

def einsum(A, B, C):
    return torch.einsum('ij,jk,kl->il', A, B, C)

Aa, Bb, Cc = [], [], []
print('generating data ...')
for i in range(10):
    Aa.append(torch.randn(2 ** 16, 2 ** 12, device='cuda'))
    Bb.append(torch.randn(2 ** 12, 2 ** 12, device='cuda'))
    Cc.append(torch.randn(2 ** 12, 2 ** 8, device='cuda'))
print(f"matrices shapes: {Aa[0].shape} {Bb[0].shape} {Cc[0].shape}")

print('run timing loops ...')

def run_benchmark(fn, Aa, Bb, Cc):
    # warmup
    for i in range(10):
        for (A, B, C) in zip(Aa, Bb, Cc):
            D = fn(A, B, C)
    # timing
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(10):
        for A, B, C in zip(Aa, Bb, Cc):
            D = fn(A, B, C)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f'{fn.__name__}: {t1 - t0}')

run_benchmark(matmul_no_compile, Aa, Bb, Cc)
run_benchmark(matmul_compile, Aa, Bb, Cc)
run_benchmark(matmul_rl, Aa, Bb, Cc)
run_benchmark(multi_dot, Aa, Bb, Cc)
run_benchmark(einsum, Aa, Bb, Cc)
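One caveat on the einsum timing that I’m not fully sure about: as far as I understand, torch.einsum only tries to optimize the contraction order when the opt_einsum package is installed and enabled; whether that applies on a given setup can be checked via torch.backends.opt_einsum, roughly like this:
# check whether einsum contraction-path optimization (via the opt_einsum package)
# is available and enabled; if not, torch.einsum contracts strictly left to right
import torch.backends.opt_einsum
print(torch.backends.opt_einsum.is_available())
print(torch.backends.opt_einsum.enabled)
print(torch.backends.opt_einsum.strategy)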
I guess the question now is: should this be something torch.compile tries to support? I’m not familiar enough with compilation to know how difficult this would be, but I’d guess that the pattern of a chain of matmuls is preserved down to the IR (?), and that it should be possible to reorder the matmuls there (at least for simple patterns such as a linear chain).
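For what it’s worth, one way to peek at what torch.compile captures is a pass-through custom backend that just prints the FX graph and runs it unchanged (only a debugging sketch; debug_backend and chain below are my own names, not anything official). The chained matmuls should show up as separate nodes in that graph, which suggests the pattern is at least visible at that level.
# sketch: a pass-through torch.compile backend that prints the captured FX graph,
# just to see whether the chained matmuls appear as separate nodes in the IR
import torch

def debug_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)     # the chain should appear as individual matmul nodes
    return gm.forward   # run the captured graph unchanged

@torch.compile(backend=debug_backend)
def chain(A, B, C):
    return A @ B @ C

chain(torch.randn(8, 4), torch.randn(4, 2), torch.randn(2, 1))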
I’m glad that you confirmed this on a more modern gpu (and hence a
different compilation backend).
That wouldn’t seem unreasonable to me. I don’t think that proposing
this as a feature request on github would be poorly received.
I don’t know anything about the internals of torch.compile, but this all
sounds plausible to me. The logic is straightforward enough (witness multi_dot()). I just don’t know whether torch.compile has hooks in
the right places to make it easy.
I don’t think, however, that chains of matrix multiplications (without
intervening non-linearities or additive biases) are that common of
a use case in pytorch, so implementing something like this – even
if relatively easy – might not be a priority.
(If you do file a feature request, I would suggest including a couple of
real-world use cases to motivate the feature.)