I have the following (severely unoptimized) CUDA kernel, which takes two batched matrices and does a windowed matrix multiplication between them. It's not super pretty, but it works (hints for how to optimize the kernel are also appreciated). However, for half-precision fp16 tensors the kernel is very slow: roughly 100x slower than fp32 on an A100 GPU (see the test script and output below). I found a GitHub issue stating that fastAtomicAdd is more efficient than gpuAtomicAdd, but I'm unable to compile the kernel with fastAtomicAdd under PyTorch 1.13 because fastAtomicAdd is not defined. If anyone knows why the kernel is so much slower for half precision and could give me a hint, that would be greatly appreciated.
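For reference, here is roughly what I tried. From browsing the ATen sources, fastAtomicAdd appears to live in the at::native namespace in ATen/native/cuda/KernelUtils.cuh (not in ATen/cuda/Atomic.cuh, which my extension includes), and unlike gpuAtomicAddNoReturn it takes a raw base pointer, a flat element index, and the tensor's numel rather than an address. The flat-index computation and variable names below are just my sketch of how the call would slot into the kernel:

#include <ATen/native/cuda/KernelUtils.cuh> // declares at::native::fastAtomicAdd

// fastAtomicAdd needs the base pointer, a flat index, and numel: for half
// precision it can then do a vectorized __half2 add on an aligned pair of
// elements, and numel tells it where the tensor ends.
scalar_t *out_ptr = out_accessor.data();
int out_numel = batch_size * seq_len * full_window_size;
int flat_idx = (b * seq_len + s) * full_window_size + w_idx;
// Cast the product back to scalar_t first: half * half promotes to float,
// which would otherwise break the template argument deduction.
scalar_t val = src_accessor[b][s][h] * other_accessor[b][h][w];
at::native::fastAtomicAdd(out_ptr, flat_idx, out_numel, val, /*fast_atomics=*/true);

The full kernel and host function: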
#include <torch/extension.h>
#include <ATen/cuda/Atomic.cuh>
#include "utils.cuh"

#include <algorithm>

#include <cuda.h>
#include <cuda_runtime.h>

#define THREADS 512
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

template <typename scalar_t>
__global__ void window_matmul_fw_cuda_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> src_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> other_accessor,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> out_accessor,
    int window_size,
    int batch_size,
    int seq_len,
    int hidden_dim)
{
  // One thread per (batch, sequence position, window offset, hidden unit);
  // the reduction over the hidden dimension happens via atomic adds.
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int h, w, s, b, w_idx, full_window_size;
  full_window_size = window_size * 2 + 1;
  if (thread_idx >= batch_size * seq_len * full_window_size * hidden_dim)
  {
    return;
  }
  // Decompose the flat thread index into (b, s, window offset, h).
  b = thread_idx / hidden_dim / full_window_size / seq_len % batch_size;
  s = thread_idx / hidden_dim / full_window_size % seq_len;
  // Absolute column index of this window position.
  w = s - window_size + (thread_idx / hidden_dim % full_window_size);
  if (w < 0 || w >= seq_len)
    return; // window sticks out past the sequence boundary
  h = thread_idx % hidden_dim;
  // Position of w inside the window (0 .. 2 * window_size);
  // equivalent to w - (s - window_size).
  w_idx = w - std::min(0, s - window_size) - std::max(0, s - window_size);
  gpuAtomicAddNoReturn(&out_accessor[b][s][w_idx],
                       src_accessor[b][s][h] * other_accessor[b][h][w]);
}
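// Worked example of the index decomposition, using the sizes from the test
// script below (batch_size = 64, seq_len = 512, hidden_dim = 64,
// window_size = 16, so full_window_size = 33). For thread_idx = 1000000:
//   h      = 1000000 % 64              = 0
//   offset = (1000000 / 64) % 33       = 16
//   s      = (1000000 / 64 / 33) % 512 = 473
//   b      =  1000000 / 64 / 33 / 512  = 0
// so w = 473 - 16 + 16 = 473 and w_idx = 473 - 457 = 16 (the window center).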
torch::Tensor window_matmul_fw_cuda(torch::Tensor src, torch::Tensor other, int window_size)
{
  CHECK_CUDA(src);
  CHECK_CUDA(other);
  CHECK_INPUT(src.dim() == other.dim());
  CHECK_INPUT(src.size(0) == other.size(0));
  CHECK_INPUT(src.size(-1) == other.size(-2));
  CHECK_INPUT(src.size(-2) == other.size(-1));
  src = src.contiguous();
  other = other.contiguous();

  // Output: one window row per sequence position, [batch, seq_len, 2 * window_size + 1].
  auto sizes = src.sizes().vec();
  sizes[2] = window_size * 2 + 1;
  torch::Tensor out = torch::zeros(sizes, src.options());

  int batch_size = src.size(0);
  int seq_len = src.size(1);
  int hidden_dim = src.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      out.scalar_type(), "window_matmul_fw_cuda", [&]
      {
        auto src_accessor = src.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto other_accessor = other.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        auto out_accessor = out.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>();
        window_matmul_fw_cuda_kernel<scalar_t>
            <<<BLOCKS(batch_size * seq_len * (window_size * 2 + 1) * hidden_dim), THREADS>>>(
                src_accessor,
                other_accessor,
                out_accessor,
                window_size,
                batch_size,
                seq_len,
                hidden_dim);
      });
  return out;
}
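For completeness, the extension is bound to Python via pybind11 along these lines (window_matmul_fw_cuda is the function above; the exported name here is illustrative, and the window_matmul the test script imports wraps it on the Python side):

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("window_matmul_fw", &window_matmul_fw_cuda,
        "Windowed matrix multiplication forward (CUDA)");
}

And the test script: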
import torch

from window_matmul import window_matmul

BATCH_SIZE = 64
SEQ_LEN = 512
HIDDEN_DIM = 64
WINDOW_SIZE = 16

mat_1 = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM).to("cuda")
mat_2 = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM).to("cuda")

for dtype in [torch.float16, torch.float32]:
    mat_1 = mat_1.to(dtype)
    mat_2 = mat_2.to(dtype)
    with torch.no_grad():
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA]
        ) as prof:
            mat_3 = window_matmul(mat_1, mat_2, window_size=WINDOW_SIZE)
    print("#" * 20 + f" {dtype} " + "#" * 20)
    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cuda_time_total", row_limit=10
        )
    )
$ python test.py
STAGE:2023-01-09 21:25:24 1706475:1706475 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-09 21:25:25 1706475:1706475 ActivityProfilerController.cpp:300] Completed Stage: Collection
#################### torch.float16 ####################
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void window_matmul_fw_cuda_kernel<c10::Half>(at::Gen... 0.00% 0.000us 0.00% 0.000us 0.000us 697.518ms 100.00% 697.518ms 697.518ms 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 8.000us 0.00% 8.000us 8.000us 1
cudaLaunchKernel 0.01% 57.000us 0.01% 57.000us 28.500us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceSynchronize 99.99% 697.393ms 99.99% 697.393ms 697.393ms 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 697.450ms
Self CUDA time total: 697.526ms
STAGE:2023-01-09 21:25:25 1706475:1706475 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-09 21:25:25 1706475:1706475 ActivityProfilerController.cpp:300] Completed Stage: Collection
#################### torch.float32 ####################
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void window_matmul_fw_cuda_kernel<float>(at::Generic... 0.00% 0.000us 0.00% 0.000us 0.000us 7.150ms 99.89% 7.150ms 7.150ms 1
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 8.000us 0.11% 8.000us 8.000us 1
cudaLaunchKernel 0.31% 22.000us 0.31% 22.000us 11.000us 0.000us 0.00% 0.000us 0.000us 2
cudaDeviceSynchronize 99.69% 7.106ms 99.69% 7.106ms 7.106ms 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 7.128ms
Self CUDA time total: 7.158ms