I have an application that does torch.matmul on large tensors. Typical dimensions for my use case are A (32, 3072) and B (3072, 4_000_000), where 32 is the batch size M, 3072 is the embedding dimension K, and 4_000_000 is N. All inputs are fp16 (half). Other shapes include M=32, K=512, N=16M, etc. A minimal sketch of the workload is below.
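For reference, this is roughly the matmul in isolation (shapes are the ones listed above; note B alone is ~24 GB at fp16, so you may need to scale N down to reproduce on smaller GPUs):

```python
import torch

# Sketch of the workload described above, using the first set of shapes.
M, K, N = 32, 3072, 4_000_000

A = torch.rand(M, K, device="cuda", dtype=torch.half)
B = torch.rand(K, N, device="cuda", dtype=torch.half)

out = torch.matmul(A, B)  # (32, 4_000_000), fp16
torch.cuda.synchronize()
```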
I am using torch 2.7. From a profile taken on an A100, I see that it uses cutlass sm75 kernels such as cutlass_tensorop_f16_s1688gemm_f16_256x128_32x2_nn_align1 for ~65% of the profile, and cutlass sm80 kernels such as cutlass_tensorop_f16_s16816gemm_f16_128x64_64x3_nt_align8 for the remaining ~35%. The profile is taken during a benchmark run of my app with 1000 requests; for a single request it picks either an sm75 kernel or an sm80 kernel. The request tensors are generated with torch.rand.
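This is roughly how I capture the per-kernel breakdown (standard torch.profiler usage; the real numbers come from the full 1000-request benchmark, this just profiles the matmul by itself):

```python
import torch
from torch.profiler import profile, ProfilerActivity

M, K, N = 32, 3072, 4_000_000
A = torch.rand(M, K, device="cuda", dtype=torch.half)
B = torch.rand(K, N, device="cuda", dtype=torch.half)

# Warm up so any heuristic/autotune choices are settled before profiling.
for _ in range(3):
    torch.matmul(A, B)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.matmul(A, B)
    torch.cuda.synchronize()

# The cutlass_tensorop_f16_* kernel names show up in this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```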
I am looking for insight into torch's kernel selection logic. Specifically, I see a significant performance difference (latency, TFLOPS) when I run the cutlass sm80 kernels directly. I am looking for ways to influence torch to always select sm80 kernels on A100 and sm90 kernels on H100.
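For context, these are the knobs I have found so far. I have not confirmed that any of them actually steers the selection toward the sm80 kernels, so please treat this as an assumption on my part rather than something that is known to work:

```python
import torch

# Assumption: these settings influence which GEMM backend/kernel torch picks;
# I have not verified that they force sm80 kernels on A100.

# Route matmuls through cuBLASLt instead of cuBLAS (or back, with "cublas").
torch.backends.cuda.preferred_blas_library("cublaslt")

# fp16 accumulation setting; may change which kernel families are eligible.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# TunableOp: benchmark candidate GEMM kernels per shape and cache the winner
# (can also be enabled via PYTORCH_TUNABLEOP_ENABLED=1 / PYTORCH_TUNABLEOP_TUNING=1).
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(True)
```

Is any of these the right lever, or is the choice entirely inside the cuBLAS/cuBLASLt heuristics and out of torch's hands?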