Hi,
I’m currently debugging the torch.sparse.mm operator at different sparsity levels. I profiled the operator to see whether I could find the actual CUDA kernel implementation, but I couldn’t.
The profiler gave me the table below, and csrmm_alg2_kernel looks like a good candidate. Does anyone know where it is implemented?
| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
|---|---|---|---|---|---|---|---|---|---|---|
| aten::addmm | 2.42% | 21.313us | 9.89% | 87.236us | 43.618us | 346.050us | 95.82% | 692.100us | 346.050us | 2 |
| aten::_sparse_mm | 0.49% | 4.288us | 62.55% | 551.527us | 551.527us | 0.000us | 0.00% | 361.154us | 361.154us | 1 |
| aten::_sparse_addmm | 0.27% | 2.374us | 5.63% | 49.664us | 49.664us | 0.000us | 0.00% | 346.050us | 346.050us | 1 |
| void cusparse::csrmm_alg2_kernel<cusparse::CsrMMPoli... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 325.218us | 90.05% | 325.218us | 325.218us | 1 |
| void cusparse::matrix_scalar_multiply_kernel<cuspars... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 15.872us | 4.39% | 15.872us | 15.872us | 1 |
| aten::zeros | 0.48% | 4.229us | 56.43% | 497.575us | 497.575us | 0.000us | 0.00% | 15.104us | 15.104us | 1 |
| aten::zero_ | 0.39% | 3.436us | 15.98% | 140.917us | 140.917us | 0.000us | 0.00% | 15.104us | 15.104us | 1 |
| aten::fill_ | 0.72% | 6.312us | 15.59% | 137.481us | 137.481us | 15.104us | 4.18% | 15.104us | 15.104us | 1 |
| void at::native::vectorized_elementwise_kernel<4, at... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 15.104us | 4.18% | 15.104us | 15.104us | 1 |
| void cusparse::csrmm_alg2_partition_kernel<128, long... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 4.960us | 1.37% | 4.960us | 4.960us | 1 |
My environment configuration and the code I used are below.
torch==2.6.0
nvcc==Build cuda_11.3.r11.3/compiler.29920130_0
python==3.11
NVIDIA GeForce RTX 3080
CUDA Version: 12.2
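In case it is useful, here is a quick, non-authoritative sketch of how the Python-side version info above can be printed (the nvcc build string and the driver-reported CUDA version come from nvcc --version and nvidia-smi, not from this snippet):

import sys
import torch

# Versions as reported by the installed PyTorch build.
print("torch:", torch.__version__)
print("python:", sys.version.split()[0])
print("CUDA (torch build):", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

The benchmarking script itself: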
import sys
import time

import torch
from torch.profiler import profile, record_function, ProfilerActivity


def main(N: int = 10):
    torch.manual_seed(123)
    device_arg = sys.argv[1] if len(sys.argv) > 1 else None
    device = torch.device("cuda" if torch.cuda.is_available() and device_arg == "cuda" else "cpu")

    sparsity_level = [0.0, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.99]
    shape = (3072, 768)
    shape2 = (1024, 768)
    output_shape = (3072, 1024)

    for sparsity in sparsity_level:
        # Dense matrix with the requested fraction of zeros, converted to CSR.
        mask = torch.rand(shape, device=device) > sparsity
        inp1 = torch.rand(shape, device=device) * mask
        sparse = inp1.to_sparse_csr()
        inp2 = torch.rand(shape2, device=device)

        total = 0
        for _ in range(N):
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
                start = time.time()
                output = torch.sparse.mm(sparse, inp2.T)
                if device.type == "cuda":
                    # Kernel launches are asynchronous; wait for completion so the timing is meaningful.
                    torch.cuda.synchronize()
                total += time.time() - start
            assert output.shape == output_shape
            prof.export_chrome_trace("trace.json")

        print(sparsity, total / N)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


if __name__ == "__main__":
    main()
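In case it helps narrow down where the kernel is launched from, here is a minimal sketch of the same single call profiled with with_stack=True so that call-stack context is recorded alongside the ops (just an idea I am trying, not a known way to get at the kernel source):

import torch
from torch.profiler import profile, ProfilerActivity

torch.manual_seed(123)
device = torch.device("cuda")

# One sparse @ dense matmul with the same shapes as above, at ~90% sparsity.
mask = torch.rand((3072, 768), device=device) > 0.9
sparse = (torch.rand((3072, 768), device=device) * mask).to_sparse_csr()
dense = torch.rand((1024, 768), device=device)

# with_stack=True records call stacks for each op, which can help trace
# which ATen operator ends up calling into cuSPARSE.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             with_stack=True) as prof:
    torch.sparse.mm(sparse, dense.T)

print(prof.key_averages(group_by_stack_n=5).table(sort_by="cuda_time_total", row_limit=10))

Even with the stacks, though, the profiler only shows which ATen op the kernel time is attributed to, not the kernel source itself, which is why I am asking where csrmm_alg2_kernel is implemented.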