Torch matmul operation with different shapes

Hello! Can someone help me with torch matmul? I have two tensors A, B of shapes (5, 10, 200) and (12, 50, 200), and I want to do something like this in one operation:

    var = torch.empty(5, 12)
    for id in range(A.shape[0]):
        var[id] = (B @ A[id].permute(1, 0)).max(1).values.sum(-1)

I tried something like this:

    var = (A.reshape(-1, 200) @ B.reshape(200, -1)).reshape(5, 10, 12, 50).permute(0, 2, 1, 3).max(-1).values.sum(-1)

but it gives a different result than the loop above.
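I suspect the mismatch comes from the fact that reshape only reinterprets memory and does not transpose, so B.reshape(200, -1) is not the same as flattening B and then transposing it. A quick check with the shapes from above:

    import torch

    B = torch.randn(12, 50, 200)
    # reshape(200, -1) just reinterprets B's memory; its columns are not the B[l, m, :] vectors
    wrong = B.reshape(200, -1)
    # flattening first and then transposing keeps each B[l, m, :] as a column
    right = B.reshape(-1, 200).T
    print(torch.equal(wrong, right))  # False in general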

I think the most obvious way to do it would be to use einsum:

    import torch

    var = torch.empty(5, 12)
    A = torch.randn(5, 10, 200)
    B = torch.randn(12, 60, 200)
    for id in range(5):
        var[id] = (B @ A[id].permute(1, 0)).max(1).values.sum(-1)
    # 'ijk,lmk->ilmj' contracts the 200-dim axis, giving a (5, 12, 60, 10) tensor;
    # max over the 60 axis, then sum over the 10 axis yields (5, 12)
    C = torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)
    print(torch.allclose(var, C))

Einsum is too slow, unfortunately.

Switching to bmm didn't seem to make a big difference; I'm guessing the matmuls are a bit small to utilize the hardware efficiently.
Here's a starting point if you want to play around with different backends and possibly change how the input shapes/dims are ordered.

    import torch
    import time

    torch.set_float32_matmul_precision('high')
    var = torch.empty(5, 12, device='cuda')
    A = torch.randn(5, 10, 200, device='cuda')
    B = torch.randn(12, 60, 200, device='cuda')

    # def func(var, A, B):
    #     return torch.matmul(B.reshape(1, 12, 60, 1, 200), A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10)).squeeze().max(2).values.sum(-1)
    # compiled = torch.compile(func, mode='max-autotune')

    for id in range(5):
        var[id] = (B @ A[id].permute(1, 0)).max(1).values.sum(-1)
    C = torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)
    D = torch.matmul(B.reshape(1, 12, 60, 1, 200), A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10)).squeeze().max(2).values.sum(-1)
    # E = compiled(var, A, B)
    print(torch.allclose(var, D, atol=1e-3, rtol=1e-3))
    print(torch.allclose(var, C, atol=1e-3, rtol=1e-3))
    # print(torch.allclose(var, E, atol=1e-3, rtol=1e-3))

    # timing
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    for id in range(5):
        var[id] = (B @ A[id].permute(1, 0)).max(1).values.sum(-1)
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    C = torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    D = torch.matmul(B.reshape(1, 12, 60, 1, 200), A.permute(0, 2, 1).reshape(5, 1, 1, 200, 10)).squeeze().max(2).values.sum(-1)
    torch.cuda.synchronize()
    t4 = time.perf_counter()
    # E = compiled(var, A, B)
    torch.cuda.synchronize()
    t5 = time.perf_counter()
    print(f"orig: {t2 - t1} einsum: {t3 - t2} bmm: {t4 - t3}")  # compile: {t5 - t4}")
Output from python3 scratch10.py:

    True
    True
    orig: 0.0003075781278312206 einsum: 0.00010568089783191681 bmm: 9.42843034863472e-05
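If the one-shot perf_counter numbers bounce around too much, torch.utils.benchmark handles warm-up and synchronization for you; a minimal sketch along the same lines (not part of the run above):

    import torch
    import torch.utils.benchmark as benchmark

    A = torch.randn(5, 10, 200, device='cuda')
    B = torch.randn(12, 60, 200, device='cuda')

    # Timer takes care of warm-up and CUDA synchronization around the timed statement
    t = benchmark.Timer(
        stmt="torch.einsum('ijk,lmk->ilmj', A, B).max(2).values.sum(-1)",
        globals={'A': A, 'B': B, 'torch': torch},
    )
    print(t.timeit(100))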

I tried something like the matmul version, but when we add the dimensions of size 1 it takes much more memory to compute than einsum does without the added dimensions.
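For what it's worth, the flattened matmul from the original post should also line up (and avoids the singleton broadcast dimensions) if B is flattened and then transposed, rather than reshaped straight to (200, -1). A sketch, not benchmarked here:

    import torch

    A = torch.randn(5, 10, 200)
    B = torch.randn(12, 60, 200)

    # single (50, 200) @ (200, 720) matmul; the intermediate is the same
    # (5, 10, 12, 60) size as einsum's, with no extra broadcast dims
    scores = (A.reshape(-1, 200) @ B.reshape(-1, 200).T).reshape(5, 10, 12, 60)
    E = scores.max(-1).values.sum(1)  # max over the 60 axis, sum over the 10 axis -> (5, 12)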