Building a C extension using generics for THCudaTensor and THCudaDoubleTensor

I’m working with a C extension and I’d like to support both THCudaDoubleTensor and THCudaTensor while using PyTorch's generics. I think I could make my own defines:

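#ifdef TH_DOUBLE
#define TH_TENSOR_TYPE THCudaDoubleTensor
#define TH_TENSOR_NEW THCudaDoubleTensor_new
#define TH_GEMM THCudaBlas_Dgemm
#else
#define TH_TENSOR_TYPE THCudaTensor
#define TH_TENSOR_NEW THCudaTensor_new
#define TH_GEMM THCudaBlas_Sgemm
#endif

/* TH_TENSOR_TYPE_new would not expand: the preprocessor sees one identifier. */
TH_TENSOR_TYPE *input = TH_TENSOR_NEW(state);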

but I was wondering what the most PyTorch-tonic way of doing it would be. Can someone point me to some code that uses PyTorch's in-house generics? I’m having trouble figuring out how PyTorch builds them and how the respective function gets called from Python.
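Defining a separate macro for every function, as with TH_GEMM, gets tedious as the number of calls grows. The usual alternative is two-level token pasting, the same trick TH's own THTensor_(NAME) macro relies on. The sketch below is self-contained under stated assumptions: THCState, THCudaTensor, THCudaDoubleTensor and their _new constructors are hypothetical stubs standing in for the real THC declarations, and MY_CONCAT, TENSOR_ and make_input are illustrative names.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical stubs for the real THC declarations, so this compiles alone. */
typedef struct THCState THCState; /* opaque; only passed through */
typedef struct { float *data; } THCudaTensor;
typedef struct { double *data; } THCudaDoubleTensor;

THCudaTensor *THCudaTensor_new(THCState *state) {
  (void)state;
  return calloc(1, sizeof(THCudaTensor));
}

THCudaDoubleTensor *THCudaDoubleTensor_new(THCState *state) {
  (void)state;
  return calloc(1, sizeof(THCudaDoubleTensor));
}

/* One type per build: pass -DTH_DOUBLE to get the double version. */
#ifdef TH_DOUBLE
#define TH_TENSOR_TYPE THCudaDoubleTensor
#else
#define TH_TENSOR_TYPE THCudaTensor
#endif

/* The extra MY_CONCAT level forces TH_TENSOR_TYPE to expand before ## runs;
   a single-level a##b would produce the identifier TH_TENSOR_TYPE_new. */
#define MY_CONCAT_(a, b) a##b
#define MY_CONCAT(a, b) MY_CONCAT_(a, b)
#define TENSOR_(NAME) MY_CONCAT(TH_TENSOR_TYPE, _##NAME)

/* Written once; resolves to THCudaTensor_new or THCudaDoubleTensor_new. */
TH_TENSOR_TYPE *make_input(THCState *state) {
  TH_TENSOR_TYPE *t = TENSOR_(new)(state);
  assert(t != NULL && "allocation failed");
  return t;
}
```

Built as is, make_input calls THCudaTensor_new; built with -DTH_DOUBLE it calls THCudaDoubleTensor_new. Either way one build yields only one type, so supporting both this way means compiling the source twice, which is the repetition a header like THGenerateFloatTypes.h automates.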


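The common way to do that is to use THGenerateFloatTypes.h; for CUDA tensors, the THC counterpart is THCGenerateFloatTypes.h. You write the type-generic code once in its own file, define TH_GENERIC_FILE (THC_GENERIC_FILE for THC) as that file's path, and include the generate header, which re-includes your file once per floating-point type with real and Real (plus CReal in THC) defined for that type. Name macros such as THCTensor_(new) then expand to THCudaTensor_new for float and THCudaDoubleTensor_new for double.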
You can take a look at this repo. Even though it was done for Lua Torch, it shows how to use this header to generate functions for different types automatically. You will need to register them in a Python module instead of a Lua module.
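The generate-header pattern can be imitated in one standalone file to see what it does. Because a snippet cannot carry a separate generic file, this sketch swaps the re-included file for a macro, GENERIC_BODY, that is expanded once per type after real (the C element type) and Real (the name fragment) are redefined. GENERIC_BODY, Vector_, CONCAT3 and the resulting FloatVector_/DoubleVector_ functions are hypothetical names, not part of TH.

```c
#include <assert.h>
#include <stddef.h>

/* Two-level paste: Real must expand before ## glues the pieces together. */
#define CONCAT3_(a, b, c) a##b##c
#define CONCAT3(a, b, c) CONCAT3_(a, b, c)

/* Hypothetical naming macro in the style of THTensor_(NAME). */
#define Vector_(NAME) CONCAT3(Real, Vector_, NAME)

/* Stand-in for a generic file: written once in terms of real and Vector_. */
#define GENERIC_BODY                                   \
  real Vector_(sum)(const real *x, size_t n) {         \
    assert(x != NULL || n == 0);                       \
    real s = 0;                                        \
    for (size_t i = 0; i < n; i++) s += x[i];          \
    return s;                                          \
  }                                                    \
  void Vector_(scale)(real *x, size_t n, real a) {     \
    for (size_t i = 0; i < n; i++) x[i] *= a;          \
  }

/* Instantiate for float, as the generate header would on its first pass. */
#define real float
#define Real Float
GENERIC_BODY
#undef real
#undef Real

/* ...and again for double. */
#define real double
#define Real Double
GENERIC_BODY
#undef real
#undef Real
```

This yields FloatVector_sum, FloatVector_scale, DoubleVector_sum and DoubleVector_scale from a single body. In an actual extension the body would stay an ordinary C file pulled in through the generate header, which keeps it readable without backslash continuations; only the naming and per-type instantiation carry over.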