Custom CUDA kernel very slow for half precision

I have the following (severely unoptimized) CUDA kernel, which takes two batched matrices and performs a windowed matrix multiplication between them. It's not pretty, but it works (hints on how to optimize the kernel are also appreciated). However, for half-precision fp16 tensors the kernel is very slow (about 100x slower on an A100 GPU; see the test script and output below). I found a GitHub issue stating that fastAtomicAdd is more efficient than gpuAtomicAdd, but I'm unable to compile the kernel with fastAtomicAdd under PyTorch 1.13 because fastAtomicAdd is not defined. If anyone knows why the kernel is so much slower for half precision and could give me a hint, that would be greatly appreciated.

#include <torch/extension.h>
#include <ATen/cuda/Atomic.cuh>

#include "utils.cuh"
#include <algorithm>
#include <cuda.h>
#include <cuda_runtime.h>

#define THREADS 512
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

template <typename scalar_t>
__global__ void window_matmul_fw_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> out_accessor,
    int window_size,
    int batch_size,
    int seq_len,
    int hidden_dim)
{
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int h, w, s, b, w_idx, full_window_size;
  full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size * hidden_dim)
  {
    return;
  }
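  // Decompose the flat thread index into (batch b, sequence position s,
  // window position w, hidden index h): one thread per (b, s, w, h) product.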
  b = thread_idx / hidden_dim / full_window_size / seq_len % batch_size;
  s = thread_idx / hidden_dim / full_window_size % seq_len;
  w = s - window_size + (thread_idx / hidden_dim % full_window_size);
  if (w < 0 || w >= seq_len)
    return;
  h = thread_idx % hidden_dim;
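  // Column w maps to slot w_idx inside the window dimension of `out`;
  // the min/max terms cancel, so w_idx == w - (s - window_size).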
  w_idx = w - std::min(0, s - window_size) - std::max(0, s - window_size);
  gpuAtomicAddNoReturn(&out_accessor[b][s][w_idx], src_accessor[b][s][h] * other_accessor[b][h][w]);
}

torch::Tensor window_matmul_fw_cuda(torch::Tensor src, torch::Tensor other, int window_size)
{
  CHECK_CUDA(src);
  CHECK_CUDA(other);

  CHECK_INPUT(src.dim() == other.dim());
  CHECK_INPUT(src.size(0) == other.size(0));
  CHECK_INPUT(src.size(-1) == other.size(-2));
  CHECK_INPUT(src.size(-2) == other.size(-1));

  src = src.contiguous();
  other = other.contiguous();

  torch::Tensor out;
  auto sizes = src.sizes().vec();
  sizes[2] = window_size * 2 + 1;
  out = torch::zeros(sizes, src.options());
  int batch_size, seq_len, hidden_dim;
  batch_size = src.size(0);
  seq_len = src.size(1);
  hidden_dim = src.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      out.scalar_type(), "window_matmul_fw_cuda", [&]
      {
        auto src_accessor = src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto other_accessor = other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto out_accessor = out.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        window_matmul_fw_cuda_kernel<scalar_t><<<BLOCKS(batch_size*seq_len*(window_size * 2 + 1)*hidden_dim),THREADS>>>(
          src_accessor,
          other_accessor,
          out_accessor,
          window_size,
          batch_size,
          seq_len,
          hidden_dim
          ); });
  return out;
}
The test script (test.py):

import torch

from window_matmul import window_matmul

BATCH_SIZE = 64
SEQ_LEN = 512
HIDDEN_DIM = 64
WINDOW_SIZE = 16

mat_1 = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM).to("cuda")
mat_2 = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM).to("cuda")


for dtype in [torch.float16, torch.float32]:
    mat_1 = mat_1.to(dtype)
    mat_2 = mat_2.to(dtype)
    with torch.no_grad():
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA]
        ) as prof:
            mat_3 = window_matmul(mat_1, mat_2, window_size=WINDOW_SIZE)
    print("#" * 20 + f" {dtype} " + "#" * 20)
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cuda_time_total", row_limit=10
        )
    )
$ python test.py 
STAGE:2023-01-09 21:25:24 1706475:1706475 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-09 21:25:25 1706475:1706475 ActivityProfilerController.cpp:300] Completed Stage: Collection
#################### torch.float16 ####################
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void window_matmul_fw_cuda_kernel<c10::Half>(at::Gen...         0.00%       0.000us         0.00%       0.000us       0.000us     697.518ms       100.00%     697.518ms     697.518ms             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.000us         0.00%       8.000us       8.000us             1  
                                       cudaLaunchKernel         0.01%      57.000us         0.01%      57.000us      28.500us       0.000us         0.00%       0.000us       0.000us             2  
                                  cudaDeviceSynchronize        99.99%     697.393ms        99.99%     697.393ms     697.393ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 697.450ms
Self CUDA time total: 697.526ms

STAGE:2023-01-09 21:25:25 1706475:1706475 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-09 21:25:25 1706475:1706475 ActivityProfilerController.cpp:300] Completed Stage: Collection
#################### torch.float32 ####################
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void window_matmul_fw_cuda_kernel<float>(at::Generic...         0.00%       0.000us         0.00%       0.000us       0.000us       7.150ms        99.89%       7.150ms       7.150ms             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.000us         0.11%       8.000us       8.000us             1  
                                       cudaLaunchKernel         0.31%      22.000us         0.31%      22.000us      11.000us       0.000us         0.00%       0.000us       0.000us             2  
                                  cudaDeviceSynchronize        99.69%       7.106ms        99.69%       7.106ms       7.106ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.128ms
Self CUDA time total: 7.158ms
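
I suspect the atomic add itself is the bottleneck: for every output element out[b][s][w_idx], all hidden_dim threads add into the same address, and an atomic add on a 16-bit Half appears to be far more expensive than one on a 32-bit float. On architectures or builds without a native 16-bit atomic it is even emulated with a compare-and-swap loop on the containing 32-bit word, roughly like this (illustrative sketch only, not the actual ATen code):

#include <cuda_fp16.h>

// Illustrative CAS emulation of a 16-bit atomic add: every retry re-reads the
// whole 32-bit word, so threads hammering the same element mostly spin.
__device__ void atomic_add_half_cas(__half* address, __half value) {
  size_t addr = reinterpret_cast<size_t>(address);
  unsigned int* word = reinterpret_cast<unsigned int*>(addr & ~size_t(2));  // containing 32-bit word
  bool high = addr & 2;                                                     // which 16-bit lane we own
  unsigned int old = *word, assumed;
  do {
    assumed = old;
    __half cur = __ushort_as_half((unsigned short)(high ? (assumed >> 16) : (assumed & 0xffffu)));
    unsigned short sum = __half_as_ushort(__hadd(cur, value));
    unsigned int updated = high ? ((assumed & 0x0000ffffu) | ((unsigned int)sum << 16))
                                : ((assumed & 0xffff0000u) | sum);
    old = atomicCAS(word, assumed, updated);
  } while (assumed != old);
}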

fastAtomicAdd was the solution in the end. I just wasn't aware that I had to pull it in from at::native. The change boils down to swapping one include and replacing gpuAtomicAddNoReturn with at::native::fastAtomicAdd; the kernels now also receive the output tensor's numel(), which fastAtomicAdd needs:
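
#include <ATen/native/cuda/KernelUtils.cuh>  // instead of #include <ATen/cuda/Atomic.cuh>

// inside the kernel: build a flat element index from the accessor strides and
// hand it to fastAtomicAdd together with the output's numel
index = b * out_accessor.stride(0) + s * out_accessor.stride(1) + w_idx * out_accessor.stride(2);
at::native::fastAtomicAdd(out_accessor.data(), index, out_numel, src_accessor[b][s][h] * other_accessor[b][h][w], true);

For reference, here's the full updated code: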

#include <torch/extension.h>
#include <ATen/native/cuda/KernelUtils.cuh>

#include "utils.cuh"
#include <algorithm>
#include <cuda.h>
#include <cuda_runtime.h>

#define THREADS 256
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

template <typename scalar_t>
__global__ void window_matmul_fw_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> out_accessor,
    int window_size,
    int batch_size,
    int seq_len,
    int hidden_dim,
    int out_numel)
{
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int h, w, s, b, w_idx, full_window_size, index;
  full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size * hidden_dim)
    return;
  b = thread_idx / hidden_dim / full_window_size / seq_len % batch_size;
  s = thread_idx / hidden_dim / full_window_size % seq_len;
  w = s - window_size + (thread_idx / hidden_dim % full_window_size);
  if (w < 0 || w >= seq_len)
    return;
  h = thread_idx % hidden_dim;
  w_idx = w - std::min(0, s - window_size) - std::max(0, s - window_size);
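  // fastAtomicAdd takes a raw data pointer, a flat element index and the
  // tensor's numel, so compute the flat index from the accessor strides.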
  index = b * out_accessor.stride(0) + s * out_accessor.stride(1) + w_idx * out_accessor.stride(2);
  at::native::fastAtomicAdd(out_accessor.data(), index, out_numel, src_accessor[b][s][h] * other_accessor[b][h][w], true);
}

torch::Tensor window_matmul_fw_cuda(torch::Tensor src, torch::Tensor other, int window_size)
{
  CHECK_CUDA(src);
  CHECK_CUDA(other);

  CHECK_INPUT(src.dim() == other.dim());
  CHECK_INPUT(src.size(0) == other.size(0));
  CHECK_INPUT(src.size(-1) == other.size(-2));
  CHECK_INPUT(src.size(-2) == other.size(-1));

  src = src.contiguous();
  other = other.contiguous();

  torch::Tensor out;
  auto sizes = src.sizes().vec();
  sizes[2] = window_size * 2 + 1;
  out = torch::zeros(sizes, src.options());
  int batch_size, seq_len, hidden_dim;
  batch_size = src.size(0);
  seq_len = src.size(1);
  hidden_dim = src.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      out.scalar_type(), "window_matmul_fw_cuda", [&]
      {
        auto src_accessor = src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto other_accessor = other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto out_accessor = out.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        window_matmul_fw_cuda_kernel<scalar_t><<<BLOCKS(batch_size*seq_len*(window_size * 2 + 1)*hidden_dim),THREADS>>>(
          src_accessor,
          other_accessor,
          out_accessor,
          window_size,
          batch_size,
          seq_len,
          hidden_dim,
          out.numel()
          ); });
  return out;
}

template <typename scalar_t>
__global__ void window_matmul_bw_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_output_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_other_accessor,
    int window_size,
    int batch_size,
    int seq_len,
    int hidden_dim,
    int grad_src_numel,
    int grad_other_numel)
{
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int h, w, s, b, w_idx, full_window_size, index;
  full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size * hidden_dim)
    return;
  b = thread_idx / hidden_dim / full_window_size / seq_len % batch_size;
  s = thread_idx / hidden_dim / full_window_size % seq_len;
  w = s - window_size + (thread_idx / hidden_dim % full_window_size);
  if (w < 0 || w >= seq_len)
    return;
  h = thread_idx % hidden_dim;
  w_idx = w - std::min(0, s - window_size) - std::max(0, s - window_size);
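  // grad_src[b][s][h] accumulates over the in-range window positions w;
  // grad_other[b][h][w] accumulates over every s whose window covers w.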
  index = b * grad_src_accessor.stride(0) + s * grad_src_accessor.stride(1) + h * grad_src_accessor.stride(2);
  at::native::fastAtomicAdd(grad_src_accessor.data(), index, grad_src_numel, other_accessor[b][h][w] * grad_output_accessor[b][s][w_idx], true);
  index = b * grad_other_accessor.stride(0) + h * grad_other_accessor.stride(1) + w * grad_other_accessor.stride(2);
  at::native::fastAtomicAdd(grad_other_accessor.data(), index, grad_other_numel, src_accessor[b][s][h] * grad_output_accessor[b][s][w_idx], true);
}

std::tuple<torch::Tensor, torch::Tensor> window_matmul_bw_cuda(
    torch::Tensor src, torch::Tensor other, int window_size, torch::Tensor grad_output)
{
  CHECK_CUDA(src);
  CHECK_CUDA(other);

  CHECK_INPUT(src.dim() == other.dim());
  CHECK_INPUT(src.size(0) == other.size(0));
  CHECK_INPUT(src.size(-1) == other.size(-2));
  CHECK_INPUT(src.size(-2) == other.size(-1));

  src = src.contiguous();
  other = other.contiguous();
  grad_output = grad_output.contiguous();

  torch::Tensor grad_src, grad_other;
  grad_src = torch::zeros(src.sizes().vec(), src.options());
  grad_other = torch::zeros(other.sizes().vec(), other.options());
  int batch_size, seq_len, hidden_dim;
  batch_size = src.size(0);
  seq_len = src.size(1);
  hidden_dim = src.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_src.scalar_type(), "window_matmul_bw_cuda", [&]
      {
        auto src_accessor = src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto other_accessor = other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto grad_output_accessor = grad_output.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto grad_src_accessor = grad_src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto grad_other_accessor = grad_other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        window_matmul_bw_cuda_kernel<scalar_t><<<BLOCKS(batch_size*seq_len*(window_size * 2 + 1)*hidden_dim),THREADS>>>(
          src_accessor,
          other_accessor,
          grad_output_accessor,
          grad_src_accessor,
          grad_other_accessor,
          window_size,
          batch_size,
          seq_len,
          hidden_dim,
          grad_src.numel(),
          grad_other.numel()
          ); });
  return std::make_tuple(grad_src, grad_other);
}

template <typename scalar_t>
__global__ void unwindow_matmul_fw_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> out_accessor,
    int window_size,
    int batch_size,
    int seq_len,
    int hidden_dim,
    int out_numel)
{
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int h, w, s, b, w_idx, full_window_size, index;
  full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size * hidden_dim)
    return;
  b = thread_idx / hidden_dim / full_window_size / seq_len % batch_size;
  s = thread_idx / hidden_dim / full_window_size % seq_len;
  w = s - window_size + (thread_idx / hidden_dim % full_window_size);
  if (w < 0 || w >= seq_len)
    return;
  h = thread_idx % hidden_dim;
  w_idx = w - std::min(0, s - window_size) - std::max(0, s - window_size);
  index = b * out_accessor.stride(0) + s * out_accessor.stride(1) + h * out_accessor.stride(2);
  at::native::fastAtomicAdd(out_accessor.data(), index, out_numel, src_accessor[b][s][w_idx] * other_accessor[b][w][h], true);
}

torch::Tensor unwindow_matmul_fw_cuda(torch::Tensor src, torch::Tensor other, int window_size)
{
  CHECK_CUDA(src);
  CHECK_CUDA(other);

  CHECK_INPUT(src.dim() == other.dim());
  CHECK_INPUT(src.size(0) == other.size(0));
  CHECK_INPUT(src.size(-1) == window_size * 2 + 1);
  CHECK_INPUT(src.size(-2) == other.size(-2));

  src = src.contiguous();
  other = other.contiguous();

  torch::Tensor out;
  auto sizes = other.sizes().vec();
  out = torch::zeros(sizes, src.options());
  int batch_size, seq_len, hidden_dim;
  batch_size = other.size(0);
  seq_len = other.size(1);
  hidden_dim = other.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      out.scalar_type(), "unwindow_matmul_fw_cuda", [&]
      {
        auto src_accessor = src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto other_accessor = other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto out_accessor = out.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();

        unwindow_matmul_fw_cuda_kernel<scalar_t><<<BLOCKS(batch_size*seq_len*(window_size * 2 + 1)*hidden_dim),THREADS>>>(
          src_accessor,
          other_accessor,
          out_accessor,
          window_size,
          batch_size,
          seq_len,
          hidden_dim,
          out.numel()
          ); });
  return out;
}

template <typename scalar_t>
__global__ void unwindow_matmul_bw_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_output_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> grad_other_accessor,
    int window_size,
    int batch_size,
    int seq_len,
    int hidden_dim,
    int grad_src_numel,
    int grad_other_numel)
{
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int h, w, s, b, w_idx, full_window_size, index;
  full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size * hidden_dim)
    return;
  b = thread_idx / hidden_dim / full_window_size / seq_len % batch_size;
  s = thread_idx / hidden_dim / full_window_size % seq_len;
  w = s - window_size + (thread_idx / hidden_dim % full_window_size);
  if (w < 0 || w >= seq_len)
    return;
  h = thread_idx % hidden_dim;
  w_idx = w - std::min(0, s - window_size) - std::max(0, s - window_size);
  index = b * grad_src_accessor.stride(0) + s * grad_src_accessor.stride(1) + w_idx * grad_src_accessor.stride(2);
  at::native::fastAtomicAdd(grad_src_accessor.data(), index, grad_src_numel, other_accessor[b][w][h] * grad_output_accessor[b][s][h], true);
  index = b * grad_other_accessor.stride(0) + w * grad_other_accessor.stride(1) + h * grad_other_accessor.stride(2);
  at::native::fastAtomicAdd(grad_other_accessor.data(), index, grad_other_numel, src_accessor[b][s][w_idx] * grad_output_accessor[b][s][h], true);
}

std::tuple<torch::Tensor, torch::Tensor> unwindow_matmul_bw_cuda(
    torch::Tensor src, torch::Tensor other, int window_size, torch::Tensor grad_output)
{
  CHECK_CUDA(src);
  CHECK_CUDA(other);

  CHECK_INPUT(src.dim() == other.dim());
  CHECK_INPUT(src.size(0) == other.size(0));
  CHECK_INPUT(src.size(-1) == window_size * 2 + 1);
  CHECK_INPUT(src.size(-2) == other.size(-2));

  src = src.contiguous();
  other = other.contiguous();
  grad_output = grad_output.contiguous();

  torch::Tensor grad_src, grad_other;
  grad_src = torch::zeros(src.sizes().vec(), src.options());
  grad_other = torch::zeros(other.sizes().vec(), other.options());
  int batch_size, seq_len, hidden_dim;
  batch_size = other.size(0);
  seq_len = other.size(1);
  hidden_dim = other.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_src.scalar_type(), "unwindow_matmul_bw_cuda", [&]
      {
        auto src_accessor = src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto other_accessor = other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto grad_output_accessor = grad_output.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto grad_src_accessor = grad_src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto grad_other_accessor = grad_other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();

        unwindow_matmul_bw_cuda_kernel<scalar_t><<<BLOCKS(batch_size*seq_len*(window_size * 2 + 1)*hidden_dim),THREADS>>>(
          src_accessor,
          other_accessor,
          grad_output_accessor,
          grad_src_accessor,
          grad_other_accessor,
          window_size,
          batch_size,
          seq_len,
          hidden_dim,
          grad_src.numel(),
          grad_other.numel()
          ); });
  return std::make_tuple(grad_src, grad_other);
}
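
A couple of notes in case someone stumbles over this later.

Why fastAtomicAdd helps, as far as I understand it: a plain atomic add on at::Half operates on a single 16-bit value, which is slow under contention (and on older architectures is emulated with a compare-and-swap loop, as sketched above). fastAtomicAdd can instead, when the element's 32-bit word lies fully inside the tensor, issue a single packed __half2 atomic with the neighboring lane set to zero. Roughly the idea (a sketch of the trick, not the actual ATen implementation; it assumes the tensor's base pointer is 4-byte aligned):

#include <cuda_fp16.h>

// Sketch: add `value` to tensor[index] with one 32-bit __half2 atomic when the
// containing word is fully inside the tensor.
__device__ void packed_half_add(__half* tensor, int index, int numel, __half value) {
  bool low_half = (index % 2 == 0);
  if ((low_half && index < numel - 1) || (!low_half && index > 0)) {
    __half zero = __ushort_as_half(0);
    __half2 packed = low_half ? __halves2half2(value, zero)   // add to the low lane
                              : __halves2half2(zero, value);  // add to the high lane
    __half2* word = reinterpret_cast<__half2*>(tensor + (low_half ? index : index - 1));
    atomicAdd(word, packed);            // adds 0 to the neighboring element
  } else {
    atomicAdd(tensor + index, value);   // scalar fallback at the tensor boundary
  }
}

And a thought on optimizing the forward kernel itself: since out[b][s][w_idx] is just a sum over hidden_dim, the forward pass doesn't need atomics at all. One thread per output element can accumulate the dot product in a register (in float, even for Half inputs) and do a single store. Untested sketch with the same indexing scheme as above:

template <typename scalar_t>
__global__ void window_matmul_fw_noatomic_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> out_accessor,
    int window_size, int batch_size, int seq_len, int hidden_dim)
{
  // One thread per output element (b, s, window slot), launched with
  // BLOCKS(batch_size * seq_len * (window_size * 2 + 1)) blocks.
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size)
    return;
  int b = thread_idx / full_window_size / seq_len % batch_size;
  int s = thread_idx / full_window_size % seq_len;
  int w = s - window_size + (thread_idx % full_window_size);
  if (w < 0 || w >= seq_len)
    return;
  int w_idx = w - s + window_size;  // slot inside the window dimension
  float acc = 0.0f;                 // accumulate in float to avoid half round-off
  for (int h = 0; h < hidden_dim; ++h)
    acc += static_cast<float>(src_accessor[b][s][h]) * static_cast<float>(other_accessor[b][h][w]);
  out_accessor[b][s][w_idx] = static_cast<scalar_t>(acc);
}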