Custom CUDA matrix multiplication is very slow

I’m trying to write a “fused” CUDA kernel for low-rank matrix multiplication.
This is the benchmarking code:

import torch
import torch.utils.benchmark as benchmark

n, m, k = 5000, 4000, 1000
A = torch.randn(n, k, device='cuda')
B = torch.randn(k, m, device='cuda')
v = torch.randn(m, device='cuda')

t0 = benchmark.Timer(
    stmt='lowrank.multiply_low_rank_matrix(A, B, v)',
    setup='''
        from torch.utils.cpp_extension import load
        lowrank = load('lowrank', ['lowrank.cpp', 'lowrank.cu'])
        print('Done')
        ''',
    globals=dict(A=A, B=B, v=v))

t1 = benchmark.Timer(
    stmt='A @ (B @ v)',
    globals=dict(A=A, B=B, v=v))

print(t0.timeit(100))
print(t1.timeit(100))

When I run it, the naive version A @ (B @ v) comes out about 19 times faster:

$ py lowrank_bench.py

<torch.utils.benchmark.utils.common.Measurement object at 0x7f5aa667ab30>
lowrank.multiply_low_rank_matrix(A, B, v)
setup:
  from torch.utils.cpp_extension import load
  lowrank = load('lowrank', ['lowrank.cpp', 'lowrank.cu'])

  4.76 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5aa6678910>
A @ (B @ v)
  249.51 us
  1 measurement, 100 runs , 1 thread

I’m wondering whether I’m doing something wrong in the benchmarking, or whether my kernel is simply badly written.
My kernel looks roughly like the code below.
It builds an intermediate vector for Bv (the same way I assume PyTorch’s A @ (B @ v) has to), and uses a call to __syncthreads() to make sure Bv is fully computed before the multiplication with A starts.

template <typename scalar_t>
__global__ void multiply_low_rank_matrix_cuda_kernel(
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits> A,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits> B,
    const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits> vector,
    torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits> result) {

  const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  __shared__ scalar_t intermediate_vector[1024]; // Shared memory for intermediate vector

  // Phase 1: Each thread computes one element of the intermediate vector (Bv)
  if(thread_id < B.size(0)) {
    intermediate_vector[thread_id] = 0;
    for (int j = 0; j < B.size(1); ++j) {
      intermediate_vector[thread_id] += B[thread_id][j] * vector[j];
    }
  }
  __syncthreads(); // Synchronize threads to ensure the intermediate vector is fully computed

  // Phase 2: Each thread computes one element of the final output vector (Av)
  if(thread_id < A.size(0)) {
    result[thread_id] = 0;
    for (int j = 0; j < A.size(1); ++j) {
      result[thread_id] += A[thread_id][j] * intermediate_vector[j];
    }
  }
}
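
For completeness, the host-side launcher in my lowrank.cu looks roughly like the following (simplified; the block size and the dispatch boilerplate shown here are illustrative rather than my exact code):

#include <torch/extension.h>

torch::Tensor multiply_low_rank_matrix_cuda(
    torch::Tensor A, torch::Tensor B, torch::Tensor vector) {
  auto result = torch::zeros({A.size(0)}, A.options());

  // One thread per output row, 1024 threads per block (illustrative values).
  const int threads = 1024;
  const int blocks = (A.size(0) + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(A.scalar_type(), "multiply_low_rank_matrix_cuda", ([&] {
    multiply_low_rank_matrix_cuda_kernel<scalar_t><<<blocks, threads>>>(
        A.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits>(),
        B.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits>(),
        vector.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits>(),
        result.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits>());
  }));

  return result;
}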

My end goal is to write something more advanced in CUDA, so I want to make sure I’m following best practices and not doing anything stupid at this stage.

I would assume your implementation isn’t optimal. For more information, you could check e.g. this blog post, which explains tiling and quantization effects, and the shared memory section of the CUDA programming guide.
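
To sketch what tiling means for a matrix-vector product (an untested illustration with made-up names, not a drop-in replacement for your extension): each block stages a chunk of the vector in shared memory once, and every thread in the block then reuses that chunk for its own row.

// Shared-memory tiled matrix-vector product: y = M * x.
// Each thread owns one output row; the block cooperatively stages
// TILE elements of x in shared memory at a time.
template <typename scalar_t, int TILE = 256>
__global__ void matvec_tiled_kernel(
    const scalar_t* __restrict__ M,  // rows x cols, row-major
    const scalar_t* __restrict__ x,  // cols
    scalar_t* __restrict__ y,        // rows
    int rows, int cols) {
  __shared__ scalar_t x_tile[TILE];

  const int row = blockIdx.x * blockDim.x + threadIdx.x;
  scalar_t acc = 0;

  for (int tile_start = 0; tile_start < cols; tile_start += TILE) {
    // All threads in the block cooperate to load the next chunk of x.
    for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
      x_tile[i] = (tile_start + i < cols) ? x[tile_start + i] : scalar_t(0);
    }
    __syncthreads();  // chunk is ready for everyone in the block

    if (row < rows) {
      for (int j = 0; j < TILE && tile_start + j < cols; ++j) {
        acc += M[row * cols + tile_start + j] * x_tile[j];
      }
    }
    __syncthreads();  // don't overwrite x_tile while others still read it
  }

  if (row < rows) {
    y[row] = acc;
  }
}

Keep in mind that __shared__ memory and __syncthreads() only cover a single block, so staging data like this helps exactly the threads of that block; anything that has to be visible across the whole grid needs a separate kernel launch or a grid-wide synchronization.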

Did you figure out the problem? I’m facing the same issue!