I’m working on a C extension and I’d like to support both THCudaDoubleTensor and THCudaTensor while making use of PyTorch’s generics. I think I could write my own defines:
#ifdef TH_DOUBLE
#define TH_TENSOR_TYPE THCudaDoubleTensor
#define TH_GEMM THCudaBlas_Dgemm
#else
#define TH_TENSOR_TYPE THCudaTensor
#define TH_GEMM THCudaBlas_Sgemm
#endif
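/* "TH_TENSOR_TYPE_new" is lexed as one (undefined) identifier, so "_new" has to be pasted on explicitly */
#define EXT_CONCAT_(a, b) a##b
#define EXT_CONCAT(a, b) EXT_CONCAT_(a, b)
#define TH_TENSOR_NEW EXT_CONCAT(TH_TENSOR_TYPE, _new)
TH_TENSOR_TYPE *input = TH_TENSOR_NEW(state);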
but I was wondering what the most PyTorch-tonic way of doing this would be. Can someone point me to some code that uses PyTorch’s in-house generics? I’m having trouble figuring out how PyTorch builds them and how the resulting functions end up being called from Python.
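For comparison with the hand-written defines, here is a minimal standalone sketch of the generics idea: one definition is stamped out once per element type by token pasting, so every function gets a type-specific name. It is an illustration, not PyTorch’s actual code. `CONCAT`, `DEFINE_TENSOR`, `FloatTensor`, `DoubleTensor` and their functions are hypothetical names, and the sketch compiles with any C99 compiler without TH or CUDA.

```c
#include <stdlib.h>

/* Two-level concatenation: the outer macro expands its arguments first,
   so CONCAT(Float, Tensor_new) becomes the single token FloatTensor_new. */
#define CONCAT_(a, b) a##b
#define CONCAT(a, b) CONCAT_(a, b)

/* Hypothetical mini "tensor" for illustration only (not the real TH/THC
   structs). The whole body is instantiated once per (real, Name) pair. */
#define DEFINE_TENSOR(real, Name)                                          \
  typedef struct {                                                         \
    real *data;                                                            \
    size_t size;                                                           \
  } CONCAT(Name, Tensor);                                                  \
                                                                           \
  static CONCAT(Name, Tensor) *CONCAT(Name, Tensor_new)(size_t size) {     \
    CONCAT(Name, Tensor) *t = malloc(sizeof *t);                           \
    if (t == NULL) return NULL;                                            \
    t->data = calloc(size, sizeof(real));                                  \
    if (t->data == NULL) { free(t); return NULL; }                         \
    t->size = size;                                                        \
    return t;                                                              \
  }                                                                        \
                                                                           \
  static void CONCAT(Name, Tensor_fill)(CONCAT(Name, Tensor) *t, real v) { \
    for (size_t i = 0; i < t->size; ++i) t->data[i] = v;                   \
  }                                                                        \
                                                                           \
  static real CONCAT(Name, Tensor_sum)(const CONCAT(Name, Tensor) *t) {    \
    real s = 0;                                                            \
    for (size_t i = 0; i < t->size; ++i) s += t->data[i];                  \
    return s;                                                              \
  }                                                                        \
                                                                           \
  static void CONCAT(Name, Tensor_free)(CONCAT(Name, Tensor) *t) {         \
    if (t != NULL) free(t->data);                                          \
    free(t);                                                               \
  }

/* Instantiate the shared code for float and for double. */
DEFINE_TENSOR(float, Float)
DEFINE_TENSOR(double, Double)
```

Calling code then selects a variant by name, e.g. `DoubleTensor *t = DoubleTensor_new(3);`. As far as I can tell, TH/THC get the same effect differently: the shared body lives in a `generic/*.c` file that is `#include`d once per type with `real` and the name prefix redefined each time. That scales better than one large macro, and the build then only has to expose the generated per-type names.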