Using Tensor functions in CUDA kernels

Hi, I’m new to using the C++ API for PyTorch.

In my current application, I have a 256×14×4×4 tensor to process. The application logic is as follows: for each 14×4×4 tensor in the larger tensor, execute a function f that computes the cumulative matrix product of its 4×4 matrices, so the output of f is still a 14×4×4 tensor. Since this is a time-critical application, I would like to use 256 GPU threads to compute these cumulative products in parallel. However, I found that I cannot pass a Tensor directly into a CUDA kernel, and I cannot call the mm method inside the kernel because it is not a __device__ function.
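Written as a recurrence, the work per 14×4×4 slice is a chain of dependent 4×4 multiplications. This is a sketch assuming the product is accumulated left to right (each new factor multiplies on the right), with A_{b,t} denoting the t-th 4×4 matrix of slice b and P_{b,t} the corresponding output matrix:

```latex
P_{b,0} = A_{b,0}, \qquad
P_{b,t} = P_{b,t-1}\,A_{b,t}, \quad t = 1,\dots,13, \quad b = 0,\dots,255
```

Each 4×4 product costs 4·4·4 = 64 multiply-adds, so one slice needs 13·64 = 832 and all 256 slices need 256·832 = 212,992. The dependency runs only along t; the 256 values of b are independent, which is what makes "one thread per slice" a natural mapping.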

I’m wondering whether there are any abstractions in PyTorch that allow this. Thanks.

You can pass the data pointer to the kernel instead of the tensor object via<scalar_t>() as described in this part of the Custom Extensions tutorial.
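Once the kernel receives only a raw pointer, it has to compute offsets into the flat buffer itself. A minimal sketch, assuming the tensor has first been made contiguous (e.g. with `.contiguous()`) so it is stored row-major as float32 with shape [256, 14, 4, 4]; in current LibTorch the typed pointer accessor is `data_ptr<scalar_t>()`. The helper name `flat_index` is hypothetical:

```cpp
#include <cstddef>

// Hypothetical helper: offset of element [b][t][r][c] in a contiguous,
// row-major float buffer of shape [B][T][4][4] (B = 256, T = 14 here).
// B itself is not needed to compute the offset, only the inner sizes.
constexpr std::size_t flat_index(std::size_t b, std::size_t t,
                                 std::size_t r, std::size_t c,
                                 std::size_t T) {
    return ((b * T + t) * 4 + r) * 4 + c;
}
```

Under nvcc the same function could additionally be marked `__host__ __device__` so a kernel can reuse the arithmetic.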

I understand that I can do that, but it returns a raw data pointer. I would like to have tensors on the GPU with capabilities such as mm, etc.

will itself call a CUDA kernel, which uses data pointers to perform the actual operation, and you can call torch functions from C++ functions. The CUDA kernel is the wrong place to call into PyTorch.
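The advice above keeps PyTorch calls in host C++ code. For this problem that is feasible because the cumulative product is sequential only along the 14 steps: each step is one batched 4×4 multiplication across all 256 problems, which in LibTorch could be a single `torch::bmm` call on `[256, 4, 4]` slices, e.g. `out.select(1, t).copy_(torch::bmm(out.select(1, t - 1), in.select(1, t)))`. The sketch below is plain C++ standing in for those batched calls so it runs without LibTorch; `cumprod_stepwise` is a hypothetical name and the left-to-right product order is an assumption:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host-level version. Outer loop: the T dependent steps.
// Inner loop over b: the B independent problems, i.e. what ONE batched
// matmul (torch::bmm) would do per step in LibTorch.
// Precondition: in.size() == B * T * 16, contiguous row-major [B][T][4][4].
std::vector<float> cumprod_stepwise(const std::vector<float>& in,
                                    std::size_t B, std::size_t T) {
    std::vector<float> out(in.size());
    if (T == 0) return out;
    const std::size_t mat = 16;           // floats per 4x4 matrix
    const std::size_t stride_b = T * mat; // floats per [T][4][4] problem

    // Step 0: out[b][0] = in[b][0]
    for (std::size_t b = 0; b < B; ++b)
        for (std::size_t k = 0; k < mat; ++k)
            out[b * stride_b + k] = in[b * stride_b + k];

    // Steps 1..T-1: out[b][t] = out[b][t-1] * in[b][t], for every b
    for (std::size_t t = 1; t < T; ++t) {
        for (std::size_t b = 0; b < B; ++b) {
            const float* prev = &out[b * stride_b + (t - 1) * mat];
            const float* a = &in[b * stride_b + t * mat];
            float* dst = &out[b * stride_b + t * mat];
            for (int r = 0; r < 4; ++r)
                for (int c = 0; c < 4; ++c) {
                    float acc = 0.0f;
                    for (int k = 0; k < 4; ++k)
                        acc += prev[r * 4 + k] * a[k * 4 + c];
                    dst[r * 4 + c] = acc;
                }
        }
    }
    return out;
}
```

With this structure the GPU work is T − 1 = 13 batched launches, each parallel over all 256 problems, and no custom kernel is needed.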

Thanks. Is there any way of having tensor functions as well as parallelism in a higher level?

Could you explain this feature a bit more, please?
As already mentioned, torch::mm will accept tensors and launch a CUDA kernel that performs the matrix multiplication in parallel on the GPU (assuming your tensors are stored on the GPU).
Custom CUDA kernels can be used if you want to write your own CUDA code (working on the data via pointers) and launch it from a custom function that accepts tensors.
What would “having tensor functions as well as parallelism in a higher level” mean?

In fact, I would like to operate at the level of matrices inside CUDA kernels. In the application, I would like to compute the cumulative product of a series of matrices, i.e., out[i] = prod(input, 0:i), where each entry is a matrix and input and output are vectors of matrices. In addition, I have 256 such problems to solve in parallel.
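A minimal sketch of the one-thread-per-problem layout described here, assuming float32 data in a contiguous [256, 14, 4, 4] buffer and out[i] = input[0]·input[1]·…·input[i] (left-to-right order). Each thread works on plain 4×4 float arrays rather than tensors, since PyTorch tensor methods are not callable from device code. The names `cumprod_4x4`, `cumprod_kernel` and the `HD` macro are hypothetical; the macro lets the per-problem function compile as ordinary host C++ for testing and as `__host__ __device__` under nvcc:

```cpp
#ifdef __CUDACC__
#define HD __host__ __device__
#else
#define HD
#endif

// Cumulative product for ONE problem: `in` and `out` each point at T
// consecutive row-major 4x4 float matrices. out[t] = in[0] * ... * in[t].
HD inline void cumprod_4x4(const float* in, float* out, int T) {
    if (T <= 0) return;
    float acc[16];  // running product in a small local array
    for (int k = 0; k < 16; ++k) { acc[k] = in[k]; out[k] = in[k]; }
    for (int t = 1; t < T; ++t) {
        const float* a = in + 16 * t;
        float next[16];
        for (int r = 0; r < 4; ++r)
            for (int c = 0; c < 4; ++c) {
                float s = 0.0f;
                for (int k = 0; k < 4; ++k) s += acc[r * 4 + k] * a[k * 4 + c];
                next[r * 4 + c] = s;
            }
        for (int k = 0; k < 16; ++k) { acc[k] = next[k]; out[16 * t + k] = next[k]; }
    }
}

#ifdef __CUDACC__
// One thread per problem: thread b handles the [T][4][4] block of problem b.
__global__ void cumprod_kernel(const float* in, float* out, int B, int T) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b < B) cumprod_4x4(in + 16 * T * b, out + 16 * T * b, T);
}
// Example launch for B = 256, T = 14 (d_in/d_out are device pointers):
//   cumprod_kernel<<<1, 256>>>(d_in, d_out, 256, 14);
#endif
```

The device pointers could come from CUDA tensors via their data pointers, as suggested earlier in the thread. Whether a single launch of 256 threads beats 13 batched library calls depends on launch overhead versus per-thread work, so it is worth timing both.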