Switching to bmm didn’t seem to make a big difference; I’m guessing the matmuls are a bit too small to utilize the hardware efficiently.
Here’s a starting point if you want to play around with different backends and possibly change how the input shapes/dims are ordered.
import torch
import time
torch.set_float32_matmul_precision('high')
var = torch.empty(5, 12, device='cuda')
A = torch.randn(5, 10, 200, device='cuda')
B = torch.randn(12, 60, 200, device='cuda')
#def func(var, A, B):
#    return torch.matmul(B.reshape(1, 12, 60, 1, 200), A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10)).squeeze().max(2).values.sum(-1)
#compiled = torch.compile(func, mode='max-autotune')
#compiled = torch.compile(func, mode='max-autotune')
# warm-up / correctness reference: per-slice matmul
for i in range(5):
    var[i] = (B @ A[i].permute(1, 0)).max(1).values.sum(-1)
C = torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)  # einsum version
D = torch.matmul(B.reshape(1, 12, 60, 1, 200), A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10)).squeeze().max(2).values.sum(-1)  # broadcasted-bmm version
# E = compiled(var, A, B)
print(torch.allclose(var, D, atol=1e-3, rtol=1e-3))
print(torch.allclose(var, C, atol=1e-3, rtol=1e-3))
# print(torch.allclose(var, E, atol=1e-3, rtol=1e-3))
# timing
torch.cuda.synchronize()
t1 = time.perf_counter()
for i in range(5):
    var[i] = (B @ A[i].permute(1, 0)).max(1).values.sum(-1)
torch.cuda.synchronize()
t2 = time.perf_counter()
C = torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)
torch.cuda.synchronize()
t3 = time.perf_counter()
D = torch.matmul(B.reshape(1, 12, 60, 1, 200), A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10)).squeeze().max(2).values.sum(-1)
torch.cuda.synchronize()
t4 = time.perf_counter()
# E = compiled(var, A, B)
torch.cuda.synchronize()
t5 = time.perf_counter()
print(f"orig: {t2 - t1} einsum: {t3 - t2} bmm: {t4 - t3}")# compile: {t5 - t4}")
# python3 scratch10.py
True
True
orig: 0.0003075781278312206 einsum: 0.00010568089783191681 bmm: 9.42843034863472e-05
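
Side note on the timing: back-to-back perf_counter deltas right after the warm-up pass can be noisy. Here's a minimal sketch using torch.utils.benchmark instead, which does warm-up runs and CUDA synchronization for you; A and B are the same tensors as in the script above, and the iteration count is arbitrary.

import torch
from torch.utils import benchmark

A = torch.randn(5, 10, 200, device='cuda')
B = torch.randn(12, 60, 200, device='cuda')

# Timer warms up and synchronizes around each measurement, so the
# reported numbers reflect actual kernel time, not launch jitter.
for label, stmt in [
    ('einsum', "torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)"),
    ('bmm', "torch.matmul(B.reshape(1, 12, 60, 1, 200), "
            "A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10))"
            ".squeeze().max(2).values.sum(-1)"),
]:
    timer = benchmark.Timer(stmt=stmt, globals={'A': A, 'B': B}, label=label)
    print(timer.timeit(100))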
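
If you want to sanity-check the "matmuls are too small" guess, one option is torch.profiler: if the per-kernel CUDA times come out in the low microseconds, comparable to launch overhead, the GPU is mostly idle between launches. A minimal sketch (profiling only the einsum variant here):

import torch
from torch.profiler import ProfilerActivity, profile

A = torch.randn(5, 10, 200, device='cuda')
B = torch.randn(12, 60, 200, device='cuda')

# warm up once so the profile isn't dominated by one-time setup cost
torch.einsum('ijk,lmk->ilmj', A, B)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))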