kevinj22
(Kevin Joseph)
May 23, 2018, 6:15pm
1
Hello,
Is it possible to transfer two ATen tensors to a single cuComplex array on the GPU without transferring the data to the CPU first?
For example:
The .cpp file declaring the CUDA function, plus the pybind11 binding code:
#include <torch/torch.h>

// CUDA declarations
void toComplex_Cuda(at::Tensor real, at::Tensor imag, int len);

// C++ declarations
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void toComplex(at::Tensor real, at::Tensor imag, int len)
{
    CHECK_INPUT(real);
    CHECK_INPUT(imag);
    toComplex_Cuda(real, imag, len);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("toComplex", &toComplex, "To Complex CUDA");
}
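For context, a binding like this would typically be JIT-compiled and called from Python along these lines. This is a sketch: the filenames toComplex.cpp and toComplex_kernel.cu are illustrative assumptions, not taken from the thread, and it requires a CUDA-capable build of PyTorch.

```python
import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension; the source filenames here are assumed.
ext = load(name="toComplexExt",
           sources=["toComplex.cpp", "toComplex_kernel.cu"])

# CHECK_INPUT requires contiguous CUDA tensors.
real = torch.randn(8, device="cuda")
imag = torch.randn(8, device="cuda")
ext.toComplex(real, imag, real.numel())
```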
The .cu file:
#include <ATen/ATen.h>
#include <iostream>
#include <cuComplex.h>
#include <cuda.h>
#include <cuda_runtime.h>
void toComplex_Cuda(at::Tensor real, at::Tensor imag, int len)
{
    cuComplex* h_data;
    cuComplex* d_data;
    h_data = (cuComplex*) malloc(len * sizeof(cuComplex));
    cudaMalloc((void**) &d_data, len * sizeof(cuComplex));
    cudaMemset(d_data, 0, len * sizeof(cuComplex));

    auto realAccessor = real.accessor<float, 1>();
    auto imagAccessor = imag.accessor<float, 1>();

    printf("Loop\n");
    for (int i = 0; i < len; i++)
    {
        d_data[i].x = realAccessor[i];
        d_data[i].y = imagAccessor[i];
    }

    cudaMemcpy(h_data, d_data, len * sizeof(cuComplex), cudaMemcpyDeviceToHost);
    printf("Second Loop\n");
    for (int i = 0; i < len; i++)
    {
        printf("(%f, %f)\n", h_data[i].x, h_data[i].y);
    }
}
The equivalent code works as desired on the CPU. When I run the code above on the GPU, I get the error Process finished with exit code 139 (interrupted by signal 11: SIGSEGV).
Does the accessor immediately transfer the data to the CPU? Is there an alternative way of doing this?
SimonW
(Simon Wang)
May 23, 2018, 9:18pm
2
I’m a bit confused. Don’t you need CUDA device functions to manipulate CUDA ptrs?
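In other words, a minimal sketch of the distinction: host code may hold and pass around a device pointer, but only device code (or the CUDA runtime) may touch the memory it addresses. Dereferencing a cudaMalloc'd pointer from the host, as the loop in the first post does, is what produces the SIGSEGV.

```cuda
cuComplex* d_data;
cudaMalloc((void**)&d_data, sizeof(cuComplex));

// d_data[0].x = 1.0f;                     // invalid: host dereference of a
                                           // device address -> SIGSEGV

cuComplex h = make_cuComplex(1.0f, 0.0f);  // valid: let the runtime copy it
cudaMemcpy(d_data, &h, sizeof(cuComplex), cudaMemcpyHostToDevice);
cudaFree(d_data);
```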
kevinj22
(Kevin Joseph)
May 23, 2018, 9:31pm
3
As in, this has to be done inside a function declared __global__?
i.e something like:
__global__
void moveVals(at::TensorAccessor<float, 1> real, at::TensorAccessor<float, 1> imag, cuComplex* dst, int len)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= len)
        return;
    dst[i].x = real[i];
    dst[i].y = imag[i];
}
for example?
I tried two variations, neither of which worked:
The issue with using at::Tensor as kernel inputs is that indexing it inside the kernel doesn't yield a single float value.
The issue with using at::TensorAccessor is that it isn't usable from device code.
SimonW
(Simon Wang)
May 23, 2018, 11:02pm
4
Can you try using tensor.data_ptr()?
kevinj22
(Kevin Joseph)
May 23, 2018, 11:28pm
5
It works now.
I use tensor.data_ptr(), which returns a void*, and cast that to a float*.
Here’s the final working code:
#include <ATen/ATen.h>
#include <iostream>
#include <cuComplex.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__
void moveVals(float* real, float* imag, cuComplex* dst, int len)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= len)
        return;
    dst[i].x = real[i];
    dst[i].y = imag[i];
}

void toComplex_Cuda(at::Tensor real, at::Tensor imag, int len)
{
    cuComplex* h_data;
    cuComplex* d_data;
    h_data = (cuComplex*) malloc(len * sizeof(cuComplex));
    cudaMalloc((void**) &d_data, len * sizeof(cuComplex));
    cudaMemset(d_data, 0, len * sizeof(cuComplex));

    // data_ptr() returns void*; cast to the tensor's element type
    float* realPtr = (float*) real.data_ptr();
    float* imagPtr = (float*) imag.data_ptr();

    moveVals<<<1, len>>>(realPtr, imagPtr, d_data, len);

    cudaMemcpy(h_data, d_data, len * sizeof(cuComplex), cudaMemcpyDeviceToHost);
    printf("Check\n");
    for (int i = 0; i < len; i++)
    {
        printf("(%f, %f)\n", h_data[i].x, h_data[i].y);
    }

    free(h_data);
    cudaFree(d_data);
}
Thanks for your help!