I’m working on a C extension and I’d like it to support both THCudaTensor and THCudaDoubleTensor, ideally using PyTorch's generics. I think I could write my own defines:
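```c
#ifdef TH_DOUBLE
#define TH_TENSOR_TYPE THCudaDoubleTensor
#define TH_GEMM THCudaBlas_Dgemm
#else
#define TH_TENSOR_TYPE THCudaTensor
#define TH_GEMM THCudaBlas_Sgemm
#endif

/* `TH_TENSOR_TYPE_new` is one identifier, so the preprocessor would not
   expand TH_TENSOR_TYPE inside it. Paste the tokens through a two-level
   helper so TH_TENSOR_TYPE is expanded first. */
#define MY_CONCAT_EXPAND(a, b) a ## b
#define MY_CONCAT(a, b) MY_CONCAT_EXPAND(a, b)
#define TH_TENSOR_(NAME) MY_CONCAT(TH_TENSOR_TYPE, _ ## NAME)

/* state: the THCState* available in the extension */
TH_TENSOR_TYPE *input = TH_TENSOR_(new)(state);
```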
But I was wondering what the most "PyTorch-tonic" way of doing this would be. Can someone point me to some code that uses PyTorch's in-house generics? I’m having trouble figuring out how PyTorch builds them and how the corresponding function gets called from Python.
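For context, TH/THC generics are built on the same token-pasting idea: a source file under `generic/` is written once in terms of `real`/`Real`, uses name macros such as `THTensor_(NAME)` (which pastes `TH`, `Real`, `Tensor_` and `NAME` together), and is re-included once per type through headers like `THGenerateFloatTypes.h`, each pass redefining `real`/`Real` first. The sketch below is a self-contained approximation of that idea, not PyTorch code: instead of re-including a generic file it uses a function-defining macro, and the names `DEFINE_GENERIC`, `MyFloatVector_sum` and `MyDoubleVector_sum` are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Two-level paste so macro arguments are expanded before ## is applied. */
#define CONCAT_3_EXPAND(a, b, c) a ## b ## c
#define CONCAT_3(a, b, c) CONCAT_3_EXPAND(a, b, c)

/* Hypothetical "generic" body written once in terms of real/Real,
   in the spirit of TH's THTensor_(NAME) naming scheme. */
#define DEFINE_GENERIC(real, Real)                                      \
    static real CONCAT_3(My, Real, Vector_sum)(const real *x, size_t n) \
    {                                                                   \
        assert(x != NULL || n == 0);                                    \
        real acc = 0;                                                   \
        for (size_t i = 0; i < n; ++i)                                  \
            acc += x[i];                                                \
        return acc;                                                     \
    }

DEFINE_GENERIC(float, Float)   /* defines MyFloatVector_sum  */
DEFINE_GENERIC(double, Double) /* defines MyDoubleVector_sum */
```

Calling `MyDoubleVector_sum(d, 3)` on `{0.25, 0.75, 1.0}` returns `2.0`. TH's real mechanism keeps the generic body in an ordinary `.c` file rather than a macro, which keeps it debuggable; the macro form is used here only so the example fits in one file.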