The code fragment below is extracted from my project. When I profile the whole project with nsys (the model was exported as TorchScript and invoked through libtorch), it shows that lin1 dispatches cublasSgemmStridedBatched with the ampere_sgemm_128x128_tt kernel to perform the GEMM, which has a huge performance impact.
But when I run this code fragment directly, nsys shows that cublasSgemm_v2 with the ampere_sgemm_128x64_tn kernel is called instead, and that kernel is much faster than ampere_sgemm_128x128_tt.
import torch
from torch.nn import Linear

lin1 = Linear(128, 128, bias=False, device='cuda')
node_vector = torch.randn((27648, 128, 3), device='cuda')  # pseudo input, same shape as the real input
v_transpose = node_vector.transpose(1, 2)  # (27648, 3, 128), non-contiguous view
v_u = lin1(v_transpose).transpose(1, 2)
I wonder how PyTorch decides which cuBLAS call and kernel to dispatch. How can I get it to call the faster kernel in my project?
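One experiment I can think of (a sketch, not a confirmed fix): since the tt/tn suffixes reflect cuBLAS transpose flags, the kernel choice may depend on whether the input tensor is contiguous. The snippet below compares feeding the non-contiguous transposed view directly against materializing a contiguous copy first; profiling each path under nsys should show which cuBLAS call each one dispatches. The `.contiguous()` workaround and the CPU fallback (so the numerical check runs without a GPU) are my assumptions, not part of the original project.

```python
import torch
from torch.nn import Linear

# Fall back to CPU when no GPU is available so the numerical check still runs;
# the kernel-dispatch comparison itself only makes sense under nsys on CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

lin1 = Linear(128, 128, bias=False, device=device)
node_vector = torch.randn((27648, 128, 3), device=device)

# Path 1: feed the non-contiguous transposed view directly (as in the project).
v_transpose = node_vector.transpose(1, 2)          # (27648, 3, 128), non-contiguous
out_strided = lin1(v_transpose).transpose(1, 2)

# Path 2 (assumption): materialize a contiguous copy first, which may let
# cuBLAS take the plain gemm path instead of the strided-batched one.
out_contig = lin1(v_transpose.contiguous()).transpose(1, 2)

# Both paths must agree numerically; only the dispatched kernel should differ.
assert torch.allclose(out_strided, out_contig, atol=1e-5)
```

If the contiguous path does select the faster kernel under nsys, the extra copy may still be cheaper than the slower GEMM, but that would need to be measured on the real workload.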
In my project it is always the 128x128_tt kernel. For context on the project snippet: lin1 and lin2 are both Linear(128, 128, bias=False), and pu and po are wrappers around nvtx_range_push and nvtx_range_pop.