Why is this high-dimensional (batched) matrix multiplication so slow?
I'm not sure why, but with these dimensions the multiplication takes a long time:
import time

import torch

# torch.backends.cudnn.benchmark = True
repeat = 1000

with torch.no_grad():
    a = torch.rand((1, 16, 4096, 128), device=torch.device("cuda"), dtype=torch.float32)
    b = torch.rand((1, 16, 128, 4096), device=torch.device("cuda"), dtype=torch.float32)

    # Time torch.matmul over `repeat` iterations
    torch.cuda.synchronize()
    start_time = time.time()
    for i in range(repeat):
        c = torch.matmul(a, b)
    torch.cuda.synchronize()
    print(f"matmul time is: {time.time() - start_time}")
    print(c.shape)

    # Time the equivalent torch.einsum contraction
    torch.cuda.synchronize()
    start_time = time.time()
    for i in range(repeat):
        d = torch.einsum('abij, abjl -> abil', a, b)
    torch.cuda.synchronize()
    print(f"einsum time is: {time.time() - start_time}")
    print(d.shape)
Output:

matmul time is: 1.6723120212554932
torch.Size([1, 16, 4096, 4096])
einsum time is: 1.6882061958312988
torch.Size([1, 16, 4096, 4096])
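
In case the wall-clock timing is the problem, I also measured with CUDA events and a warm-up pass instead of time.time(). This is just a minimal sketch of the same benchmark; the FLOP count is my own back-of-the-envelope arithmetic, and the numbers will depend on the GPU:

import torch

repeat = 1000
with torch.no_grad():
    a = torch.rand((1, 16, 4096, 128), device="cuda")
    b = torch.rand((1, 16, 128, 4096), device="cuda")

    # Warm-up so first-call setup cost (kernel selection, caching) is not counted
    for _ in range(10):
        torch.matmul(a, b)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        c = torch.matmul(a, b)
    end.record()
    torch.cuda.synchronize()

    ms = start.elapsed_time(end)  # total milliseconds over all repeats
    # One batched GEMM costs roughly 2 * batch * M * K * N FLOPs
    flops = 2 * 16 * 4096 * 128 * 4096
    print(f"{ms / repeat:.3f} ms per call, "
          f"~{flops * repeat / (ms / 1e3) / 1e12:.1f} TFLOP/s")

This gives essentially the same per-call time, so the measurement itself seems fine; the question is whether this throughput is expected for these shapes.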