Compute high-dimensional matrix multiplication using the matmul operator


Computing a high-dimensional (batched) matrix multiplication is very slow. Why?

I'm not sure why, but with these dimensions:

import time

import torch

# torch.backends.cudnn.benchmark = True
repeat = 1000
with torch.no_grad():
    # Batched operands: 16 independent (4096 x 128) @ (128 x 4096) products
    a = torch.rand((1, 16, 4096, 128), device=torch.device("cuda"), dtype=torch.float32)
    b = torch.rand((1, 16, 128, 4096), device=torch.device("cuda"), dtype=torch.float32)
    torch.cuda.synchronize()  # wait for setup to finish before starting the clock
    start_time = time.time()
    for i in range(repeat):
        c = torch.matmul(a, b)
    torch.cuda.synchronize()  # wait for all queued kernels to finish before stopping the clock
    print(f"matmul time is: {time.time() - start_time}")
    print(c.shape)
    
    torch.cuda.synchronize()
    start_time = time.time()
    for i in range(repeat):
        # the same contraction expressed with einsum
        d = torch.einsum('abij, abjl -> abil', a, b)
    torch.cuda.synchronize()
    print(f"einsum time is: {time.time() - start_time}")
    print(d.shape)

matmul time is: 1.6723120212554932
torch.Size([1, 16, 4096, 4096])
einsum time is: 1.6882061958312988
torch.Size([1, 16, 4096, 4096])
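
To put the timing in perspective, it can help to convert it into throughput. A back-of-the-envelope calculation, using the shapes and the matmul time from the run above (the resulting figure is only as accurate as the wall-clock measurement):

# Back-of-the-envelope throughput for the run above.
# One batched matmul does 2 * batch * heads * M * N * K floating-point ops.
batch, heads, M, K, N = 1, 16, 4096, 128, 4096
repeat = 1000
elapsed_s = 1.6723120212554932  # measured matmul time from the output above

flops_per_call = 2 * batch * heads * M * N * K   # ~68.7 GFLOP per call
total_flops = flops_per_call * repeat
print(f"per call: {flops_per_call / 1e9:.1f} GFLOP")
print(f"throughput: {total_flops / elapsed_s / 1e12:.1f} TFLOP/s")  # ~41 TFLOP/s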
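
Also note that a hand-rolled timing loop like this folds one-off costs (e.g. first-call cuBLAS initialization, since there is no warm-up pass) into the total. As a sanity check, here is a minimal sketch using torch.utils.benchmark, which handles warm-up and CUDA synchronization for you; the tensor shapes match the snippet above:

import torch
import torch.utils.benchmark as benchmark

a = torch.rand((1, 16, 4096, 128), device="cuda")
b = torch.rand((1, 16, 128, 4096), device="cuda")

# Timer warms up and synchronizes CUDA around each measurement,
# so the result is not skewed by one-off initialization cost.
t = benchmark.Timer(
    stmt="torch.matmul(a, b)",
    globals={"torch": torch, "a": a, "b": b},
)
print(t.timeit(100))  # mean time per call over 100 runs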