I’m trying to write a batched-GEMM operator where each problem in the batch has a different shape. The following is the script I used for profiling and debugging:
import torch
import myop
import time

# Per-problem GEMM shapes: the i-th problem is (m[i] x k[i]) @ (k[i] x n[i]).
m = torch.tensor([100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30, 100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30,
                  100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30, 100, 128, 20, 284, 89, 10, 82, 92, 10, 20, 30])
k = torch.tensor([48, 256, 33, 92, 17, 520, 48, 256, 33, 92, 17, 520, 48, 256, 33, 92, 17, 520, 33, 92, 17, 520,
                  48, 256, 33, 92, 17, 520, 48, 256, 33, 92, 17, 520, 48, 256, 33, 92, 17, 520, 33, 92, 17, 520])
n = torch.tensor([200, 30, 17, 30, 200, 30, 17, 30, 200, 30, 17, 30, 200, 30, 17, 30, 200, 30, 17, 30, 200, 30,
                  200, 30, 17, 30, 200, 30, 17, 30, 200, 30, 17, 30, 200, 30, 17, 30, 200, 30, 17, 30, 200, 30])

A_arr = []
B_arr = []
C_arr = []
for mi, ni, ki in zip(m.tolist(), n.tolist(), k.tolist()):
    A_arr.append(torch.rand(mi, ki).to(0))
    B_arr.append(torch.rand(ki, ni).to(0))
    C_arr.append(torch.rand(mi, ni).to(0))

# myop takes each operand as one flat buffer: the per-problem tensors
# flattened and concatenated back to back.
A = torch.cat([_.view(-1) for _ in A_arr], 0)
B = torch.cat([_.view(-1) for _ in B_arr], 0)
C = torch.cat([_.view(-1) for _ in C_arr], 0)

# Baseline: one torch.matmul per problem, issued sequentially from Python.
torch.cuda.synchronize()
tic = time.time()
for _ in range(100):
    for i in range(len(m)):
        Ai = A_arr[i]
        Bi = B_arr[i]
        Ci = Ai @ Bi
torch.cuda.synchronize()
toc = time.time()
print(toc - tic)

# Custom op: one call per iteration covering the whole batch.
tic = time.time()
for _ in range(100):
    myop.multiple_gemm(A, B, C, m, n, k, False, False)
torch.cuda.synchronize()
toc = time.time()
print(toc - tic)
Here myop is a C++ extension I wrote. It basically calls the cublasSgemm operator (non-transpose for both operands) for each (a, b, c) in the batch sequentially, and it only uses a single stream.
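For context, the inner loop of the extension boils down to the following simplified sketch (error checking is omitted, and g_handle is a hypothetical stand-in for the handle the extension creates once and binds to its single stream; since cuBLAS is column-major, each row-major C = A @ B is issued as C^T = B^T * A^T, which keeps both operands non-transposed):

#include <torch/extension.h>
#include <cublas_v2.h>

extern cublasHandle_t g_handle;  // hypothetical: created once, bound to a single stream

void multiple_gemm(torch::Tensor A, torch::Tensor B, torch::Tensor C,
                   torch::Tensor m, torch::Tensor n, torch::Tensor k,
                   bool trans_a, bool trans_b) {  // transpose flags ignored in this sketch
    const float alpha = 1.0f, beta = 0.0f;
    float* a = A.data_ptr<float>();
    float* b = B.data_ptr<float>();
    float* c = C.data_ptr<float>();
    auto ms = m.accessor<int64_t, 1>();
    auto ns = n.accessor<int64_t, 1>();
    auto ks = k.accessor<int64_t, 1>();
    for (int64_t i = 0; i < m.size(0); ++i) {
        const int mi = ms[i], ni = ns[i], ki = ks[i];
        // Row-major C(mi x ni) = A(mi x ki) @ B(ki x ni), expressed in
        // cuBLAS's column-major convention as C^T = B^T * A^T.
        cublasSgemm(g_handle, CUBLAS_OP_N, CUBLAS_OP_N,
                    ni, mi, ki,
                    &alpha,
                    b, ni,   // B^T is (ni x ki) column-major, ld = ni
                    a, ki,   // A^T is (ki x mi) column-major, ld = ki
                    &beta,
                    c, ni);  // C^T is (ni x mi) column-major, ld = ni
        a += static_cast<int64_t>(mi) * ki;
        b += static_cast<int64_t>(ki) * ni;
        c += static_cast<int64_t>(mi) * ni;
    }
}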
However, I found that myop’s execution time is longer than sequentially applying torch.matmul in Python. To investigate why, I profiled both with Nsight Systems.
I found that PyTorch only triggers the sgemm_32x32x32_NN_vec kernel, while myop triggers several kinds of kernels, including maxwell_sgemm_128_64_nn and sgemm_32x32x32_NN_vec, and I can confirm that both ways produce exactly the same result. I was wondering: what is PyTorch’s behavior for GEMM on the GPU? Has it enabled some environment variable that might force cuBLAS not to select kernels in the default way, etc.?
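In case it’s relevant to the question: one piece of state I can think of that differs between the two paths is the cuBLAS handle itself. Math mode and the attached workspace live on the handle, and I assume they could influence cuBLAS’s kernel choice. A minimal way for an extension to reuse the handle PyTorch manages for the current stream, rather than a privately created one, would be something like this (a sketch, assuming a standard PyTorch C++ extension build):

#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

cublasHandle_t shared_handle() {
    // Reuse the handle PyTorch manages for the current CUDA stream instead
    // of cublasCreate()-ing a private one, so both paths see the same
    // math-mode and workspace configuration.
    return at::cuda::getCurrentCUDABlasHandle();
}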