Sparse.mm CUDA kernel implementation

Hi,

I’m currently debugging the torch.sparse.mm operator at different sparsity levels. I profiled the operator hoping to identify the actual CUDA kernel implementation being dispatched, but I couldn’t find it.

I got the table below. It seems that csrmm_alg2_kernel is a good candidate. Does anyone know where it is implemented?

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::addmm         2.42%      21.313us         9.89%      87.236us      43.618us     346.050us        95.82%     692.100us     346.050us             2  
                                       aten::_sparse_mm         0.49%       4.288us        62.55%     551.527us     551.527us       0.000us         0.00%     361.154us     361.154us             1  
                                    aten::_sparse_addmm         0.27%       2.374us         5.63%      49.664us      49.664us       0.000us         0.00%     346.050us     346.050us             1  
void cusparse::csrmm_alg2_kernel<cusparse::CsrMMPoli...         0.00%       0.000us         0.00%       0.000us       0.000us     325.218us        90.05%     325.218us     325.218us             1  
void cusparse::matrix_scalar_multiply_kernel<cuspars...         0.00%       0.000us         0.00%       0.000us       0.000us      15.872us         4.39%      15.872us      15.872us             1  
                                            aten::zeros         0.48%       4.229us        56.43%     497.575us     497.575us       0.000us         0.00%      15.104us      15.104us             1  
                                            aten::zero_         0.39%       3.436us        15.98%     140.917us     140.917us       0.000us         0.00%      15.104us      15.104us             1  
                                            aten::fill_         0.72%       6.312us        15.59%     137.481us     137.481us      15.104us         4.18%      15.104us      15.104us             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      15.104us         4.18%      15.104us      15.104us             1  
void cusparse::csrmm_alg2_partition_kernel<128, long...         0.00%       0.000us         0.00%       0.000us       0.000us       4.960us         1.37%       4.960us       4.960us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------ 

The code used is below, along with my environment configuration.

torch==2.6.0
nvcc==Build cuda_11.3.r11.3/compiler.29920130_0
python==3.11

NVIDIA GeForce RTX 3080
CUDA Version: 12.2
import sys
import time
import torch
from torch.profiler import profile, record_function, ProfilerActivity


def main(N: int = 10):
    torch.manual_seed(123)
    device_arg = sys.argv[1] if len(sys.argv) > 1 else None
    device = torch.device("cuda" if torch.cuda.is_available() and device_arg == "cuda" else "cpu")
    sparsity_level = [0.0, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.99]
    shape = (3072, 768)
    shape2 = (1024, 768)
    output_shape = (3072, 1024)

    for sparsity in sparsity_level:
        mask = torch.rand(shape, device=device) > sparsity
        inp1 = torch.rand(shape, device=device) * mask
        sparse = inp1.to_sparse_csr() 
        inp2 = torch.rand(shape2, device=device)

        total = 0
        for _ in range(N):
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
                start = time.time()
                output = torch.sparse.mm(sparse, inp2.T)
                total += time.time() - start
            assert output.shape == output_shape
        prof.export_chrome_trace("trace.json")
        print(sparsity, total/N)

    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Script entry point: run the full benchmark sweep with default settings.
if __name__ == "__main__":
    main()