Weird cuBLAS GEMM kernel calls

Hi,

The code fragment below is extracted from my project. When I profile the whole project with nsys (the model has been exported as TorchScript, and I call it from libtorch), it shows that lin1 here calls cublasSgemmStridedBatched and the ampere_sgemm_128x128_tt kernel to perform the GEMM, which has a huge performance impact.

But when I run this code fragment directly, nsys shows that cublasSgemm_v2 and the ampere_sgemm_128x64_tn kernel are called. This kernel is much faster than ampere_sgemm_128x128_tt.

import torch
from torch.nn import Linear

lin1 = Linear(128, 128, bias=False, device='cuda')
node_vector = torch.randn((27648, 128, 3), device='cuda')  # pseudo input, same size as real input
v_tranpose = node_vector.transpose(1, 2)
v_u = lin1(v_tranpose).transpose(1, 2)
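To see what the GEMM actually receives, it can help to print the layout of the transposed input. This is a minimal sketch that only assumes the shapes from the snippet above (it falls back to CPU when CUDA is unavailable). `transpose(1, 2)` returns a non-contiguous view, so the innermost dimension of the Linear input is not stride-1. As far as I know, sizes and strides are among the things PyTorch's matmul path looks at when deciding between a plain GEMM and a batched one, but the exact rules depend on the PyTorch version, so treat this as a diagnostic rather than an explanation of the dispatch.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
node_vector = torch.randn((27648, 128, 3), device=device)  # same shape as the pseudo input
v_tranpose = node_vector.transpose(1, 2)

# transpose() returns a view: same storage, sizes and strides swapped for dims 1 and 2
print(v_tranpose.shape)            # torch.Size([27648, 3, 128])
print(v_tranpose.stride())         # (384, 1, 3)
print(v_tranpose.is_contiguous())  # False
print(v_tranpose.data_ptr() == node_vector.data_ptr())  # True, no copy was made
```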

I wonder how PyTorch determines the kernel type. How can I get the right kernel called in my project?
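Besides nsys, the ops and kernels a call dispatches to can be listed from Python with `torch.profiler`, which makes it quicker to compare variants of the snippet. This is a swapped-in tool, not what the thread used, and `profiled_op_names` is a hypothetical helper name. GPU kernel names only show up when CUDA activity is profiled, so on a CPU-only machine the list contains just the `aten::` ops.

```python
import torch
from torch.nn import Linear
from torch.profiler import ProfilerActivity, profile


def profiled_op_names(fn, *args):
    """Hypothetical helper: run fn(*args) under torch.profiler and return event names."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)  # also record GPU kernels
    with profile(activities=activities) as prof:
        fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # let the kernels finish inside the profiled region
    return [evt.key for evt in prof.key_averages()]


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lin1 = Linear(128, 128, bias=False, device=device)
    node_vector = torch.randn((27648, 128, 3), device=device)
    names = profiled_op_names(lambda x: lin1(x.transpose(1, 2)).transpose(1, 2), node_vector)
    # On CUDA, cuBLAS GEMM kernels typically have "gemm" somewhere in their name
    print([n for n in names if n.startswith("aten::") or "gemm" in n.lower()])
```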

Your current code will fail if the transposes are removed, so could you show the code that calls into different kernels using the “same” workload?
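The reason the transposes can't simply be dropped: `Linear` multiplies along the last dimension, and for `node_vector` that dimension has size 3, not 128. A small CPU-only sketch, with the batch size reduced from 27648 to 8 only to keep it light:

```python
import torch
from torch.nn import Linear

lin1 = Linear(128, 128, bias=False)
node_vector = torch.randn((8, 128, 3))  # last dim is 3

failed = False
try:
    lin1(node_vector)  # the last dim (3) does not match in_features (128)
except RuntimeError as err:
    failed = True
    print("without transpose:", err)

# Moving the 128-sized dim last makes the shapes line up; transpose back afterwards
v_u = lin1(node_vector.transpose(1, 2)).transpose(1, 2)
print(v_u.shape)  # torch.Size([8, 128, 3])
```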

v_tranpose = node_vector.transpose(1, 2)
v_u = lin1(v_tranpose).transpose(1, 2)

is exactly the code that calls into different kernels.
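One thing that might be worth trying (an untested suggestion, not a confirmed fix for the TorchScript/libtorch case) is to make the 2D nature of this GEMM explicit: flatten the transposed input into a contiguous (N*3, 128) matrix before calling `F.linear`, so there is no batch dimension left for a batched GEMM. The helper name `linear_over_channels` is hypothetical. The `reshape` costs one extra copy of the input, so whether this is a net win has to be checked with nsys in the real project.

```python
import torch
import torch.nn.functional as F
from torch.nn import Linear


def linear_over_channels(lin: Linear, node_vector: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply `lin` over dim 1 of a (N, C_in, K) tensor -> (N, C_out, K).

    Same result as lin(node_vector.transpose(1, 2)).transpose(1, 2), but the input
    is flattened to a single 2D matrix first, so F.linear sees no batch dimension.
    """
    n, c_in, k = node_vector.shape
    # reshape() copies here, because the transposed view is not contiguous
    x2d = node_vector.transpose(1, 2).reshape(n * k, c_in)  # (N*K, C_in)
    y2d = F.linear(x2d, lin.weight, lin.bias)                # (N*K, C_out)
    return y2d.view(n, k, -1).transpose(1, 2)                # (N, C_out, K)
```

Usage would be `v_u = linear_over_channels(lin1, node_vector)` in place of the two lines above; the returned tensor is again a transposed view, just like the original `v_u`.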

Your code shows a single linear layer calling into a single kernel, which is ampere_sgemm_128x64_tn in my setup.

In my project it becomes the 128x128_tt kernel. This is the code snippet from my project, where lin1 and lin2 are Linear(128, 128, bias=False), and pu and po are wrappers around nvtx_range_push and nvtx_range_pop.
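The pu/po wrappers themselves aren't shown. A minimal sketch of what such a pair could look like in Python is a context manager around `torch.cuda.nvtx.range_push` / `torch.cuda.nvtx.range_pop`, which guarantees the range is popped even if the wrapped code raises. `nvtx_range` is a hypothetical name, and the CUDA check is only there so the sketch also runs on CPU-only machines.

```python
import contextlib

import torch


@contextlib.contextmanager
def nvtx_range(name: str):
    """Hypothetical stand-in for the pu/po pair: push an NVTX range on entry, pop on exit."""
    use_nvtx = torch.cuda.is_available()  # skip NVTX entirely on CPU-only setups
    if use_nvtx:
        torch.cuda.nvtx.range_push(name)
    try:
        yield
    finally:
        if use_nvtx:
            torch.cuda.nvtx.range_pop()
```

Usage, for example: `with nvtx_range("lin1"): v_u = lin1(v_tranpose).transpose(1, 2)`, which then shows up as a named range in the nsys timeline.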

The nsys result shows:

Maybe there are too many complex operations before this Linear?
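If upstream operations are the cause, one plausible mechanism (an assumption, not confirmed in this thread) is that the real input reaches lin1 with the same shape as the pseudo input but different strides, or with different autograd state. The hypothetical helper `layout_signature` collects those properties so they can be printed at the call site in both runs. The `real_like` tensor below only imitates an upstream `permute`; it is not the project's actual input.

```python
import torch


def layout_signature(t: torch.Tensor) -> tuple:
    """Hypothetical helper: GEMM-input properties that can differ even when shapes match."""
    return (
        tuple(t.shape),
        t.stride(),
        t.dtype,
        t.is_contiguous(),
        t.requires_grad,
        torch.is_grad_enabled(),  # global autograd mode, not a tensor property
    )


n = 1024  # reduced batch size, only to keep the sketch light
pseudo = torch.randn(n, 128, 3)
# Same shape (n, 128, 3) but a different memory layout, as an upstream permute might produce
real_like = torch.randn(3, n, 128).permute(1, 2, 0)

print(layout_signature(pseudo.transpose(1, 2)))     # strides (384, 1, 3)
print(layout_signature(real_like.transpose(1, 2)))  # strides (128, 131072, 1)
```

Here both transposed inputs have shape (1024, 3, 128), but only the second one has a stride-1 innermost dimension, so the two would reach the GEMM with different layouts.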