CUDA kernel function that handles Complex Tensors

How to write a CUDA kernel that handles complex arguments?


I have been writing a CUDA extension whose inputs are obtained with the torch.view_as_real function, which takes a cfloat tensor and returns a float tensor with an extra trailing dimension of size two (the real and imaginary components).

Here is a code snippet:

import torch
from my_cuda_extension import multiplication_complex 

cuda = torch.device('cuda')
x = torch.view_as_real(torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10)
h = torch.view_as_real(torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10)

multiplication_complex(x, h)

This is the code I currently have inside the kernel.

template <typename scalar_t>
__device__ __forceinline__ void multiplication_complex(
    scalar_t a, scalar_t b,
    scalar_t c, scalar_t d,
    scalar_t* out_re, scalar_t* out_im) {
    *out_re += a*c - b*d;
    *out_im += a*d + b*c;
}

template <typename scalar_t>
__global__ void multiplication_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> x,
    const torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> h,
    torch::PackedTensorAccessor32<scalar_t, 5, torch::RestrictPtrTraits> out,
    const int THREADS, const int C, const int W, const int PLANE_SIZE){

    const int b = blockIdx.x;
    const int f = blockIdx.y;

    const int cell_idx = blockIdx.z * THREADS + threadIdx.x;
    if (cell_idx >= PLANE_SIZE) return;

    const int i = cell_idx/W;
    const int j = cell_idx % W;

    scalar_t out_re = 0.0;
    scalar_t out_im = 0.0;

    for (int c = 0; c < C; ++c) {

        const scalar_t x_re = x[b][c][i][j][0];
        const scalar_t x_im = x[b][c][i][j][1];

        const scalar_t h_re = h[f][c][i][j][0];
        const scalar_t h_im = h[f][c][i][j][1];

        multiplication_complex(x_re, x_im, h_re, h_im, &out_re, &out_im);
    }

    out[b][f][i][j][0] = out_re;
    out[b][f][i][j][1] = out_im;
}

The kernel is dispatched by:

printf("About to DISPATCH\n");
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "multiplication_complex_cuda",
  ([&] {
      multiplication_cuda_kernel<scalar_t><<<GRID_SIZE, THREADS>>>(
      x.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
      h.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
      out.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
      THREADS, C, W, PLANE_SIZE);
   })
);
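
For context, the launch parameters used above (GRID_SIZE, THREADS, C, W, PLANE_SIZE) are computed on the host essentially as in the complete example further down in the thread. A minimal sketch of that setup (make_grid is just an illustrative helper name, not part of the actual extension):

#include <torch/extension.h>
#include <cuda_runtime.h>

// One block per (batch, filter, chunk-of-plane) triple; each block covers
// up to `threads` cells of the H x W plane.
dim3 make_grid(const at::Tensor& x, const at::Tensor& h, int threads) {
    const int B = x.size(0);          // batch size
    const int F = h.size(0);          // number of filters
    const int H = x.size(2);
    const int W = x.size(3);
    const int PLANE_SIZE = H * W;     // cells per (H, W) plane
    const int Z = (PLANE_SIZE + threads - 1) / threads;
    return dim3(B, F, Z);
}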

I want to handle complex numbers in a more straightforward manner and use AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES instead, thus avoiding torch.view_as_real and calling the function as

import torch
from my_cuda_extension import multiplication_complex 

cuda = torch.device('cuda')
x = torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10
h = torch.rand(size=(1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10

multiplication_complex(x, h)

The new dispatch should look like:

printf("About to DISPATCH\n");
 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(x.scalar_type(), "multiplication_complex_cuda",
    ([&] {
        multiplication_cuda_kernel<scalar_t><<<GRID_SIZE, THREADS>>>(
        x.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        h.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        out.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
        THREADS, C, W, PLANE_SIZE);
    }));

But I am not able to write a multiplication_cuda_kernel function that actually gets called. This is my current attempt:

template <typename scalar_t>
__global__ void multiplication_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> x,
    const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> h,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> out,
    const int THREADS, const int C, const int W, const int PLANE_SIZE
){

    const int b = blockIdx.x; // Image position in Batch
    const int f = blockIdx.y; // Filter position

    const int cell_idx = blockIdx.z * THREADS + threadIdx.x; // data point/pixel/cell index in h x w plane
    if (cell_idx >= PLANE_SIZE) return;

    if (threadIdx.x == 0) printf("Hello Block %d\n", blockIdx.x);
}

However, this “Hello Block” message is never printed, so the multiplication_cuda_kernel function is not being called.
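
As a side note, a general CUDA debugging step (not part of the extension itself) is to check the CUDA error state right after the launch; an invalid launch configuration or a dtype/accessor mismatch then shows up as an explicit error instead of nothing happening. A minimal sketch:

#include <cuda_runtime.h>
#include <cstdio>

// Call this immediately after the <<<GRID_SIZE, THREADS>>> launch.
// cudaGetLastError() reports launch-configuration problems, and
// cudaDeviceSynchronize() surfaces errors raised while the kernel runs.
inline void check_last_cuda_call(const char* what) {
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("%s launch failed: %s\n", what, cudaGetErrorString(err));
    }
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("%s execution failed: %s\n", what, cudaGetErrorString(err));
    }
}

Recent PyTorch versions also provide C10_CUDA_KERNEL_LAUNCH_CHECK() (from c10/cuda/CUDAException.h) for the launch-error part.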


I have been looking at this file as an example (pytorch/aten/src/ATen/native/cuda/UnaryComplexKernels.cu), in particular the snippet below.

template<typename T>
__host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T> v) {
  return std::arg(v);
}

void angle_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "angle_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return angle_wrapper(a);
    });
  });
}

However, none of my attempts inspired by the snippet above managed to use a kernel function dispatched for the complex dtype. Any help on getting the multiplication_cuda_kernel function called, and on how to implement it with a complex scalar_t, would be much appreciated.

Thank you for your time.

Wondering if @tom would be available to shed some light on this.

If you want people to look at it, I would think that providing a complete C++ extension, or ideally a Python script that loads the (CUDA) csrc with something like
test_ext = torch.utils.cpp_extension.load_inline("test_ext", "", csrc, is_python_module=True, verbose=True) and runs it, would lower the burden of getting your snippets to run. This way, one can dive right into the matter without (re-)writing the glue code.

I’m a bit doubtful about the code being able to compile but not run the function; something seems to be funny about that. But I’m hoping you’ll show a complete C++ snippet to look at. :wink:

Best regards

Thomas

Hello Thomas, thank you for your feedback.

I just uploaded the complete code here for people interested in taking a look.

The link above contains the current working version of this extension. I want to change it so that I work with complex numbers in a “cleaner” manner by:

  1. Being able to call complex_multiplication(x, h) with cfloat tensors.
  2. Handling complex types (somehow) with the scalar_t template in the CUDA kernel, similar to the angle_wrapper function from PyTorch mentioned above.

Eduardo.

If you allow me to re-post the code here directly, the following seems to work for me for a PyTorch git checkout from today:

import torch
import torch.utils.cpp_extension

csrc = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
__device__ __forceinline__ void elementwise_operation(
        scalar_t a,
        scalar_t b,
        scalar_t* out
    ) {
    *out += a*b;
}

template <typename scalar_t>
__global__ void complex_multiplication_cuda_kernel_v1(
        const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> x,
        const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> h,
        torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> out,
        const int THREADS, const int C, const int W, const int PLANE_SIZE
    ){

    const int b = blockIdx.x; // Image position in Batch
    const int f = blockIdx.y; // Filter position

    const int cell_idx = blockIdx.z * THREADS + threadIdx.x; // data point/pixel/cell index in h x w plane
    if (cell_idx >= PLANE_SIZE) return;

    const int i = cell_idx/W;
    const int j = cell_idx % W;

    scalar_t out_ = 0.0;

    /****************************************************************
     * Dimensions should be 
     * x   -> (B, C, H, W) ~> Each b is size C*H*W
     * h   -> (F, C, H, W) ~>      f is  ""
     * out -> (B, F, H, W) ~>      b is  ""  F*H*W
     ****************************************************************/
    for (int c = 0; c < C; ++c) {

        const scalar_t x_ = x[b][c][i][j];

        const scalar_t h_ = h[f][c][i][j];

        elementwise_operation(x_, h_, &out_);
    }

    out[b][f][i][j] = out_;
}

/**
 * Multiplies two tensors of Complex Tensors
 * @param x
 * @param h
 * @param output
 */
at::Tensor complex_multiplication_cuda_v1(at::Tensor x, at::Tensor h) {
    const int THREADS = 1024;

    const int B = x.size(0);
    const int F = h.size(0);
    const int C = x.size(1);
    const int H = x.size(2);
    const int W = x.size(3);
    const int PLANE_SIZE = H*W;

    const auto Z = (H*W + THREADS - 1)/THREADS;
    const dim3 GRID_SIZE(B, F, Z);

    auto output = torch::zeros(
        {B, F, H, W},
        x.options()
        );

    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(x.scalar_type(), "complex_multiplication_cuda_v1",
        ([&] {
            complex_multiplication_cuda_kernel_v1<scalar_t><<<GRID_SIZE, THREADS>>>(
                x.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
                h.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
                output.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
                THREADS, C, W, PLANE_SIZE
            );
        })
    );

    return output;
}

/**
 * Multiplies two tensors of Complex Tensors
 * @param x
 * @param h
 * @return output
 */
at::Tensor complex_multiplication(at::Tensor x, at::Tensor h) {
    printf("CPP: my_pytorch_extensions_cuda called\\n");
    return complex_multiplication_cuda_v1(x, h);
}

/********************************
 * Binding Functions
 ********************************/

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("complex_multiplication", &complex_multiplication, "CUDA kernel for multiplication of complex tensors");
}
"""

test_ext = torch.utils.cpp_extension.load_inline("test_ext", "", csrc, is_python_module=True, verbose=True)

torch.set_printoptions(precision=2, linewidth=120)
cuda = torch.device('cuda')
x = torch.rand((1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10
h = torch.rand((1, 1, 4, 4), dtype=torch.cfloat, device=cuda)*10

output = test_ext.complex_multiplication(x, h)

print((x[:, None] * h[None, :]).sum(2))
print(output)

This gives me the same complex tensor twice, which seemed to be your aim.
What does (and should) happen behind the scenes is that you’re using PyTorch’s c10::complex<float>/c10::complex<double>, and the operator overloads defined for those provide the *, +, and assignment used in the kernel. (I’m using gcc 10 as the compiler; not sure if it matters, but PyTorch currently triggers a bad thing in gcc 11.) Of course, in this setup you should use c10::complex throughout, with real() and imag() to get and set the real and imaginary parts individually (see the documentation for c10::complex for more details); the real and imaginary parts are not a separate dimension anymore.
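
As a minimal standalone sketch of that API (rotate_by_i is just an illustrative name, not part of the extension above): c10::complex behaves much like std::complex, with real() and imag() doubling as getters and setters and the arithmetic operators overloaded for both host and device code.

#include <c10/util/complex.h>

// Works in both host and device code when compiled with nvcc.
template <typename T>
__host__ __device__ c10::complex<T> rotate_by_i(c10::complex<T> z) {
    T re = z.real();                         // read the real part
    z.imag(z.imag() + re);                   // imag(value) sets the imaginary part
    return z * c10::complex<T>(T(0), T(1));  // multiply by i via the overloaded operator*
}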

Note that complex support was added incrementally, so I would recommend trying a rather recent PyTorch version (e.g. nightly or testing) if you are not using one already.

Best regards

Thomas


Thank you for your code snippet. I am using torch.__version__ == 1.10.0 and it worked nicely.
I am not sure why the code I had previously didn’t work. It was probably something wrong with how I used the complex types in the kernel functions.

For the sake of completeness, here are the kernel helper functions I tested for manipulating the complex numbers. In the example below, calculate_multiplication_long is overloaded for c10::complex, so the complex case can behave differently from the real types if needed.

template <typename scalar_t>
__device__ __forceinline__ void calculate_multiplication(
    scalar_t x, scalar_t h, scalar_t* out
    ) {
    *out += x*h;
}

template <typename scalar_t>
__device__ __forceinline__ void calculate_multiplication_long(
    scalar_t x, scalar_t h, scalar_t* out
    ) {
    *out += x*h;
}

template <typename T>
__device__ __forceinline__ void calculate_multiplication_long(
    c10::complex<T> x, c10::complex<T> h, c10::complex<T>* out
    ) {
    const T a=x.real(), b=x.imag(), c=h.real(), d=h.imag();
    out->real(out->real() + a*c - b*d);
    out->imag(out->imag() + a*d + b*c);
}
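
With both templates in scope, the same call site inside a kernel dispatched via AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES selects the generic version for real scalar_t and the c10::complex overload for complex scalar_t, since the latter is the more specialized template. A sketch assuming the two overloads above are visible (demo_kernel and its flat-pointer signature are only for illustration):

template <typename scalar_t>
__global__ void demo_kernel(const scalar_t* x, const scalar_t* h,
                            scalar_t* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    scalar_t acc = scalar_t(0);
    calculate_multiplication_long(x[i], h[i], &acc);  // picks the right overload per scalar_t
    out[i] = acc;
}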

A bit off-topic, but did you build it with the new CUDA 11.6? If yes, can you share your build configuration? I have a CUDA C++ kernel I would like to use in Python PyTorch code, but I have configuration problems.

No, the machine I am using happens to have CUDA version 10.2, not 11.6.
