How to write a CUDA kernel that handles complex arguments?
I have been writing a CUDA extension whose inputs are obtained through the torch.view_as_real function, which takes a cfloat tensor and represents it as a float tensor with an extra dimension of size two (the real and imaginary components).
Here is a code snippet:
import torch
from my_cuda_extension import multiplication_complex
cuda = torch.device('cuda')
x = torch.view_as_real(torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10)
h = torch.view_as_real(torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10)
multiplication_complex(x, h)
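For reference, torch.view_as_real maps each (1, 1, 4, 4) cfloat tensor to a (1, 1, 4, 4, 2) float32 tensor, with the trailing dimension holding the real and imaginary parts:
print(x.shape, x.dtype)
# torch.Size([1, 1, 4, 4, 2]) torch.float32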
This is the code I currently have inside the kernel:
template <typename scalar_t>
__device__ __forceinline__ void multiplication_complex(
    scalar_t a, scalar_t b,
    scalar_t c, scalar_t d,
    scalar_t* out_re, scalar_t* out_im) {
  // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i, accumulated into the outputs
  *out_re += a*c - b*d;
  *out_im += a*d + b*c;
}
template <typename scalar_t>
__global__ void multiplication_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> x,
    const torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> h,
    torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> out,
    const int THREADS, const int C, const int W, const int PLANE_SIZE) {
  const int b = blockIdx.x;                                // image position in batch
  const int f = blockIdx.y;                                // filter position
  const int cell_idx = blockIdx.z * THREADS + threadIdx.x; // pixel index in the H x W plane
  if (cell_idx >= PLANE_SIZE) return;
  const int i = cell_idx / W;
  const int j = cell_idx % W;  // integer modulo; fmod is for floating-point values
  scalar_t out_re = 0.0;
  scalar_t out_im = 0.0;
  for (int c = 0; c < C; ++c) {
    const scalar_t x_re = x[b][c][i][j][0];
    const scalar_t x_im = x[b][c][i][j][1];
    const scalar_t h_re = h[f][c][i][j][0];
    const scalar_t h_im = h[f][c][i][j][1];
    multiplication_complex(x_re, x_im, h_re, h_im, &out_re, &out_im);
  }
  out[b][f][i][j][0] = out_re;
  out[b][f][i][j][1] = out_im;
}
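The launch configuration is not shown above; a setup consistent with the indexing would look roughly like this (B and F stand for the batch and filter counts; the names here are illustrative):
// Illustrative launch setup: one block per (batch, filter) pair,
// with blockIdx.z tiling the H x W plane in chunks of THREADS.
const int THREADS = 256;
const int PLANE_SIZE = H * W;
const dim3 GRID_SIZE(B, F, (PLANE_SIZE + THREADS - 1) / THREADS);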
This kernel is dispatched with:
printf("About to DISPATCH\n");
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "multiplication_complex_cuda",
([&] {
multiplication_cuda_kernel<scalar_t><<<GRID_SIZE, THREADS>>>(
x.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
h.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
out.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
THREADS, C, W, PLANE_SIZE);
})
);
I want to handle complex numbers in a more straightforward manner. Since AT_DISPATCH_FLOATING_TYPES only instantiates the lambda for float and double, I want to use AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES instead, avoiding torch.view_as_real altogether and calling the function as:
import torch
from my_cuda_extension import multiplication_complex
cuda = torch.device('cuda')
x = torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10
h = torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10
multiplication_complex(x, h)
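For clarity, the result I am after corresponds to this einsum over the channel dimension (my reading of the indexing in the kernel above):
# Reference result: sum over channels c of x[b,c,i,j] * h[f,c,i,j]
expected = torch.einsum('bcij,fcij->bfij', x, h)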
The new dispatch should look like:
printf("About to DISPATCH\n");
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(x.scalar_type(), "multiplication_complex_cuda",
([&] {
multiplication_cuda_kernel<scalar_t><<<GRID_SIZE, THREADS>>>(
x.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
h.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
out.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
THREADS, C, W, PLANE_SIZE);
}));
BUT I am not able to write a multiplication_cuda_kernel function that actually gets called. (Under this dispatch, scalar_t should be c10::complex<float> for cfloat inputs, as far as I understand.) Here is the stripped-down version I am testing with:
template <typename scalar_t>
__global__ void multiplication_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> x,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> h,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> out,
    const int THREADS, const int C, const int W, const int PLANE_SIZE) {
  const int b = blockIdx.x;                                // image position in batch
  const int f = blockIdx.y;                                // filter position
  const int cell_idx = blockIdx.z * THREADS + threadIdx.x; // pixel index in the H x W plane
  if (cell_idx >= PLANE_SIZE) return;
  if (threadIdx.x == 0) printf("Hello Block %d\n", blockIdx.x);
}
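Once the launch itself works, the body I intend to put after the bounds check relies on c10::complex providing device-side operator* and operator+= (an untested sketch):
// Untested sketch: c10::complex<T> defines its arithmetic operators
// as host/device functions, so the accumulation should just work.
const int i = cell_idx / W;
const int j = cell_idx % W;
scalar_t acc = scalar_t(0);
for (int c = 0; c < C; ++c) {
  acc += x[b][c][i][j] * h[f][c][i][j];
}
out[b][f][i][j] = acc;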
As it stands, though, the Hello Block message is never printed, so the multiplication_cuda_kernel function is apparently not being launched at all.
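A launch-error check right after the kernel call should surface whatever is going wrong; the standard PyTorch pattern (from c10/cuda/CUDAException.h) is:
multiplication_cuda_kernel<scalar_t><<<GRID_SIZE, THREADS>>>(/* ... */);
C10_CUDA_KERNEL_LAUNCH_CHECK();  // throws if cudaGetLastError() reports a launch failure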
I have been using this file as an example (pytorch/aten/src/ATen/native/cuda/UnaryComplexKernels.cu), in particular the snippet below.
template<typename T>
__host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T> v) {
  return std::arg(v);
}

void angle_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "angle_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return angle_wrapper(a);
    });
  });
}
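As a reduced test of the dispatch itself, I would expect even a trivial kernel instantiated the same way to print (a hypothetical minimal example):
template <typename scalar_t>
__global__ void hello_kernel() {
  if (threadIdx.x == 0) printf("Hello from block %d\n", blockIdx.x);
}

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(x.scalar_type(), "hello_cuda", ([&] {
  hello_kernel<scalar_t><<<GRID_SIZE, THREADS>>>();
}));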
However, none of my attempts inspired by these examples have successfully used a kernel dispatched for a complex dtype. Any help on getting the multiplication_cuda_kernel function called, and on how to implement it with a complex scalar_t, would be much appreciated.
Thank you for your time.