About PyTorch C++ operator backend selection

I am curious about the backends PyTorch uses for its operators. In version 2.0, I have noticed at least three kinds of backends: cuBLAS (the oldest backend, not open source), CUTLASS (an open-source C++ operator library developed by NVIDIA) and Triton (a Python-like language and compiler for writing highly efficient custom deep-learning primitives, developed by OpenAI). I have some questions about this (for PyTorch 2.0+):

  1. What principles does PyTorch follow when selecting a backend? For example, all three backends implement the GEMM operator, so which one will PyTorch choose? Where is the source code for this?
  2. Operators have hyperparameters (e.g. tile size), and the best values may differ greatly depending on the problem size. How does PyTorch determine these hyperparameters at runtime? Where is the source code that determines them?
    Take my profiling results below as an example:

    I profiled GEMMs of size (1024, 1024, 1024), (2048, 2048, 2048), (4096, 4096, 4096) and (8192, 8192, 8192), and noticed that they all use CUTLASS kernels, but with different configurations (e.g. different tile sizes: 128x128 for the 2048 shape versus 128x256 for the 4096 shape).
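To make question 2 concrete, here is a minimal, hypothetical sketch of one general way a library can pick tile sizes at runtime: benchmark each candidate configuration once per problem shape and cache the winner (shape-keyed autotuning). The names `TileConfig`, `autotune` and the `bench` callable are illustrative assumptions; this is not PyTorch's, cuBLAS's or CUTLASS's actual selection code, which may rely on built-in heuristics instead of measurement.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Sequence, Tuple

@dataclass(frozen=True)
class TileConfig:
    # Hypothetical GEMM tile shape: rows (block_m) x columns (block_n) per thread block.
    block_m: int
    block_n: int

Shape = Tuple[int, int, int]  # (M, N, K) of the GEMM

# Cache of the best config found so far for each problem shape.
_cache: Dict[Shape, TileConfig] = {}

def autotune(shape: Shape,
             candidates: Sequence[TileConfig],
             bench: Callable[[Shape, TileConfig], float]) -> TileConfig:
    """Return the fastest candidate for `shape`, benchmarking only on a cache miss.

    `bench` is assumed to return a run time for one config (lower is better).
    """
    if shape not in _cache:
        _cache[shape] = min(candidates, key=lambda cfg: bench(shape, cfg))
    return _cache[shape]
```

In a real setting `bench` would time the kernel on the GPU; repeated calls with the same shape reuse the cached choice without re-benchmarking, so the tuning cost is paid once per shape.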

Source code:

import torch
from triton.testing import do_bench

def get_flops(N, get_kernels=True):
    A = torch.randn(N, N, device='cuda', dtype=torch.float16)
    B = torch.randn(N, N, device='cuda', dtype=torch.float16)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        # Warm up once so one-time library initialization is not captured in the trace.
        f()
        torch.cuda.synchronize()
        with torch.profiler.profile() as prof:
            f()
            torch.cuda.synchronize()

        # Print the GEMM/GEMV/Triton kernels launched by torch.mm.
        for e in prof.events():
            if "gemm" in e.name or "triton" in e.name or "gemv" in e.name:
                print(f"{N}: {e}")

    # Time per call in milliseconds, measured by Triton's benchmarking helper.
    ms = do_bench(f)
    iters_per_second = 1e3 / ms
    # An (M x K) @ (K x N) matmul performs 2*M*N*K floating-point operations.
    flop_count = 2 * A.shape[0] * A.shape[1] * B.shape[1]
    tflops_achieved = iters_per_second * flop_count / 1e12
    print(f"{N}: {tflops_achieved:.2f} TF/s")

for N in [1024, 2048, 4096, 8192]:
    get_flops(N)
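The throughput arithmetic at the end of the script can be pulled into a small pure-Python helper that makes the units explicit: `do_bench` reports milliseconds, and an (M x K) @ (K x N) GEMM performs 2*M*N*K floating-point operations. The function name `gemm_tflops` and the 1.0 ms timing in the example are hypothetical, chosen only to illustrate the calculation.

```python
def gemm_tflops(m: int, n: int, k: int, ms: float) -> float:
    """Achieved TFLOP/s for an (m x k) @ (k x n) GEMM that took `ms` milliseconds."""
    flop_count = 2 * m * n * k          # one multiply and one add per inner-product term
    seconds = ms / 1e3                  # milliseconds -> seconds
    return flop_count / seconds / 1e12  # FLOP/s -> TFLOP/s

# Illustrative only: a 4096^3 GEMM that took 1.0 ms would reach about 137.44 TF/s.
print(f"{gemm_tflops(4096, 4096, 4096, 1.0):.2f} TF/s")
```

This computes the same quantity as `iters_per_second * flop_count / 1e12` in the script above, since `iters_per_second = 1e3 / ms`.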

Did you figure it out?