Using a Python tensor with the C++ API

Good evening,

I am trying to switch from 0.4.1 to 1.0, which means transforming my old C torch code into C++.
I am trying to take a Tensor created in Python, use it in C++ (and then CUDA) through at::Tensor, and convert it into a flat float * or int * array. This worked perfectly in C by doing

THFloatTensor_data(points)

and it modified the array in place, so that I could use the new values in Python. However, in C++ I don’t know whether doing:

points.data()

is the way to go: my code compiles, but the values inside the array are not changed on the Python side.

Any ideas?
I have no compilation errors whatsoever and the code runs without error, but again, the Tensor I pass from Python to the C++ function is not modified as it was in C.

.data<float>() is indeed a way to get a pointer to the data. You need to take care of strides yourself, but you probably know that. For pointwise access there is also .accessor<float, dim>(). With the returned accessor you can index until you get a scalar to read/write.
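
For illustration, here is a minimal sketch of both approaches (the function and its arguments are made up; it assumes a 2D float CPU tensor):

#include <ATen/ATen.h>

// Hypothetical example: overwrite row `row` of a 2D float tensor, first
// via the raw pointer (applying strides by hand), then via an accessor.
void fill_row(at::Tensor t, int64_t row, float value) {
    // Raw pointer: you must apply the strides yourself.
    float* p = t.data<float>();
    int64_t s0 = t.stride(0), s1 = t.stride(1);
    for (int64_t j = 0; j < t.size(1); ++j)
        p[row * s0 + j * s1] = value;

    // Accessor: indexing applies the strides for you.
    auto a = t.accessor<float, 2>();
    for (int64_t j = 0; j < t.size(1); ++j)
        a[row][j] = value;
}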

Best regards

Thomas

Good morning Tom,
I am not sure I understand how to take care of the strides myself. My goal is to obtain a pointer to the data as a float * or int * so I can use it in CUDA code. Here is an example of a CUDA kernel that uses such pointers:

__global__ void cubeselect(int n, float radius, const float* xyz, int* idx_out)
{
    int batch_idx = blockIdx.x;
    xyz += batch_idx * n * 3;      // advance to this batch's points
    idx_out += batch_idx * n * 8;  // 8 neighbor slots per point
    float temp_dist[8];
    float judge_dist = radius * radius;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        float x = xyz[i * 3];
        float y = xyz[i * 3 + 1];
        float z = xyz[i * 3 + 2];
        for (int j = 0; j < 8; j++) {
            temp_dist[j] = 1e8;
            idx_out[i * 8 + j] = i; // if not found, just return itself
        }
        for (int j = 0; j < n; j++) {
            if (i != j) {
                float tx = xyz[j * 3];
                float ty = xyz[j * 3 + 1];
                float tz = xyz[j * 3 + 2];
                float dist = (x - tx) * (x - tx) + (y - ty) * (y - ty) + (z - tz) * (z - tz);
                if (dist <= judge_dist) {
                    // keep the nearest point in each of the 8 octants around point i
                    int _x = (tx > x);
                    int _y = (ty > y);
                    int _z = (tz > z);
                    int temp_idx = _x * 4 + _y * 2 + _z;
                    if (dist < temp_dist[temp_idx]) {
                        idx_out[i * 8 + temp_idx] = j;
                        temp_dist[temp_idx] = dist;
                    }
                }
            }
        }
    }
}

Previously, in C, I was doing something like this:

float * outp    = THFloatTensor_data(out);
int   * idxp    = THIntTensor_data(idx);

I think auto outp = out.data<float>() and auto idxp = idx.data<int>() should do the trick.
Your kernel is implicitly assuming that you have a contiguous tensor. If you wanted to avoid that assumption, you could pass PackedAccessors to the kernel; they can roughly be used like arrays. ATen’s native batch norm is an example that uses PackedAccessors.
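
As a rough sketch (using the PyTorch 1.0-era names, which may differ in later versions; the kernel body is elided), it might look like:

__global__ void cubeselect_acc(float radius,
                               at::PackedTensorAccessor<float, 3> xyz,      // [b, n, 3]
                               at::PackedTensorAccessor<int, 3> idx_out) {  // [b, n, 8]
    int batch_idx = blockIdx.x;
    for (int i = threadIdx.x; i < xyz.size(1); i += blockDim.x) {
        float x = xyz[batch_idx][i][0]; // the accessor applies the strides for you
        // ... same neighbor search as above, writing idx_out[batch_idx][i][k] ...
    }
}

// host side: create the accessors when launching
// cubeselect_acc<<<b, 512>>>(radius,
//                            xyz.packed_accessor<float, 3>(),
//                            idx_out.packed_accessor<int, 3>());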

Best regards

Thomas

This is what I tried, but the behavior of the code differs completely between C and C++ even though the input data is the same. I have to say that I am kind of lost on this issue, which is forcing me and my project partners to stick to 0.4.1.

Well, if you post the C++ extension module that comes with the kernel, along with an input and the expected result, people might take a look.

Alright, the code is quite big, so I am not going to paste everything.
I have two types of functions in the C++ code. The first one:

void select_cube(at::Tensor xyz, at::Tensor idx_out, int b, int n, float radius)
{
    auto output = idx_out.contiguous().data<int>();
    auto input  = xyz.contiguous().data<float>();
    cubeSelectLauncher(b, n, radius, input, output);
}

It hands the tensors over to the CUDA code above.
The other one is a bit longer and manipulates a Python tensor directly in C++, as follows:

void interpolate(int b, int n, int m, at::Tensor xyz1p, at::Tensor xyz2p, at::Tensor distp, at::Tensor idxp) {
    float* xyz1 = xyz1p.contiguous().data<float>();
    float* xyz2 = xyz2p.contiguous().data<float>();
    float* dist = distp.contiguous().data<float>();
    int*   idx  = idxp.contiguous().data<int>();

    for (int i = 0; i < b; ++i) {
        for (int j = 0; j < n; ++j) {
            float x1 = xyz1[j * 3 + 0];
            float y1 = xyz1[j * 3 + 1];
            float z1 = xyz1[j * 3 + 2];
            double best1 = 1e40, best2 = 1e40, best3 = 1e40;
            int besti1 = 0, besti2 = 0, besti3 = 0;
            // find the three nearest neighbors of point j among the m points in xyz2
            for (int k = 0; k < m; ++k) {
                float x2 = xyz2[k * 3 + 0];
                float y2 = xyz2[k * 3 + 1];
                float z2 = xyz2[k * 3 + 2];
                double d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
                if (d < best1) {
                    best3 = best2; besti3 = besti2;
                    best2 = best1; besti2 = besti1;
                    best1 = d;     besti1 = k;
                } else if (d < best2) {
                    best3 = best2; besti3 = besti2;
                    best2 = d;     besti2 = k;
                } else if (d < best3) {
                    best3 = d;     besti3 = k;
                }
            }
            dist[j * 3]     = best1;
            idx[j * 3]      = besti1;
            dist[j * 3 + 1] = best2;
            idx[j * 3 + 1]  = besti2;
            dist[j * 3 + 2] = best3;
            idx[j * 3 + 2]  = besti3;
        }
        // advance to the next batch
        xyz1 += n * 3;
        xyz2 += m * 3;
        dist += n * 3;
        idx  += n * 3;
    }
}

If I am converting the Tensor to a float * correctly using data(), then the error might be coming from the second type of function. But again, the code is the same as it was in C; only the tensor-to-pointer conversion changed.

Actually, I’d have expected you to post something I can copy-paste into Python and run. The cpp_extension module makes it easy to provide a complete snippet.
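
For reference, the binding part that makes such a snippet self-contained is small; a sketch using the function names from your code above:

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("select_cube", &select_cube, "cube neighbor selection (CUDA)");
    m.def("interpolate", &interpolate, "three-nearest-neighbor interpolation (CPU)");
}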

The problem I would immediately see is that

auto output = idx_out.contiguous().data<int>();
auto input  = xyz.contiguous().data<float>();

is not good.

  1. You need to hang on to the tensors whose data pointers or accessors you are using. Here, .contiguous() may return a new temporary tensor if the input is not contiguous; the pointer you extracted then dangles once that temporary is destroyed at the end of the statement, and any writes land in the copy rather than in the tensor you passed in from Python, which is why the changes never show up on the Python side.
  2. You need to return the output (and I’d recommend having select_cube return it).
    So
auto output_ = idx_out.contiguous(); // why would this not be torch::empty(...)?
auto output  = output_.data<int>();
auto input_  = xyz.contiguous();
auto input   = input_.data<float>();
...
return output_;
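
Applied to your select_cube, the fix might look like this (a sketch; cubeSelectLauncher is your launcher from above):

at::Tensor select_cube(at::Tensor xyz, at::Tensor idx_out, int b, int n, float radius)
{
    auto output_ = idx_out.contiguous(); // keep the (possibly copied) tensor alive
    auto input_  = xyz.contiguous();
    cubeSelectLauncher(b, n, radius, input_.data<float>(), output_.data<int>());
    return output_; // hand the tensor that actually holds the results back to Python
}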

Best regards

Thomas