[Extensions] Pass non-tensor objects / maintain C++ state

I’m working with a C++ class that I would like to use from Python, during inference, to implement a low-latency GEMM operation. The class holds onto certain CUDA objects, such as matrix descriptors and handles, to reduce the overhead of subsequent calls (since the matrix sizes and types don’t change).

The documentation I have found for the C++ extension only shows how to pass tensors between Python and C++. Is it possible for the C++ side to “maintain state” between forward-pass iterations, so I can avoid re-initializing the descriptors and handles? Or, failing that, is it possible to pass generic objects from Python to C++ and back (via the PyTorch extension API), so I can initialize them on the first iteration and re-use them on subsequent ones?

If I have to re-initialize everything each iteration, then this technique is no faster than the current implementation, and it’s only feasible from C++, which would greatly reduce its impact. So any help here, even a vague pointer, would be much appreciated!

I’ll partially answer my own question. It looks like handles are shared / created / destroyed following these examples: cuBLAS, cuSPARSE. It would still be great if I could store other state, such as matrix descriptors, or efficiently pass these back and forth between Python…

Essentially, PyTorch keeps some state in global variables (though you need to make sure not to run into threading issues with them) or in thread-local variables.
You could do the same in your extension.
Or you could pass in a uint8 tensor that you use as a buffer.

Best regards


Thanks @tom. I see the DeviceThreadHandlePool class that you use for thread-local state (e.g. handles); can you point me to an example, or to the class you use, for global state?

And with regard to passing in a uint8 tensor as a buffer: I assume that means the extension only supports passing tensor objects to C++?

For this to work, I think I would need to somehow “assign” the raw data pointer to a new object. For example, suppose I wanted to store a string object in the uint8 tensor buffer like this. Is there a way to do that, e.g. by replacing the data pointer in the tensor? I’d presumably call a C++ cleanup function from Python, if calling cudaFree wasn’t the right thing to do for this raw pointer.