I’m following the tutorial on Custom C++ / CUDA Extensions.
In the end there are two implementations of the module:
- In C++, imported as `lltm_cpp`
- In CUDA, imported as `lltm_cuda`

(Note that although `lltm_cpp` also runs on the GPU, CUDA is still a lot faster.)

Is there a recommended way to use LLTM (or any other extension) inside my Python `torch.nn.Module` so that `to(device)` would execute `lltm_cpp` when running on CPU and `lltm_cuda` when running on GPU?
In more detail:

```python
class FinalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lltm = LLTM()

model = FinalModel()
model.to("cpu")   # executes lltm_cpp
model.to("cuda")  # executes lltm_cuda
```
Any help would be greatly appreciated.
The dispatching mechanism should already abstract away the device and should execute the corresponding code as described in the tutorial. Is this not working for you and do you see any errors after implementing the custom CUDA extension?
The dispatching mechanism seems to work only for C++ extension, but not for CUDA extension (unless I’m missing something).
`setup.py` looks as follows:

```python
ext_modules = [
    CUDAExtension(
        "lltm_cuda",
        sources=["lltm_cuda.cpp", "lltm_cuda_kernel.cu"],
    ),
]
```
I can successfully run `lltm_cpp` on CPU and GPU. But when running `lltm_cuda` on CPU I get:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: x must be a CUDA tensor
```
Therefore I would like to automatically run `lltm_cpp` on CPU and `lltm_cuda` on GPU. My question is: is there a recommended way to dispatch these ops on the Python side?
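Until the dispatcher handles this for you, one plain-Python workaround is to branch on the device of the input tensor. This is only a hypothetical sketch: `lltm_cpp` and `lltm_cuda` below are stubs standing in for the compiled extension modules, and `FakeTensor` stands in for a `torch.Tensor`, so the example is self-contained.

```python
# Hypothetical sketch: pick the extension based on the input's device.

class _StubExt:
    """Stand-in for a compiled extension module."""
    def __init__(self, name):
        self.name = name
    def forward(self, x):
        return f"ran {self.name}"

lltm_cpp = _StubExt("lltm_cpp")    # would really be: import lltm_cpp
lltm_cuda = _StubExt("lltm_cuda")  # would really be: import lltm_cuda

class FakeTensor:
    """Minimal stand-in exposing just the .is_cuda flag of a torch.Tensor."""
    def __init__(self, device):
        self.is_cuda = device.startswith("cuda")

def lltm_forward(x):
    # Route to the backend matching the input's device, so the same call
    # works after model.to("cpu") or model.to("cuda").
    ext = lltm_cuda if x.is_cuda else lltm_cpp
    return ext.forward(x)
```

In a real `nn.Module`, `forward(self, x)` would do the same check on `x.is_cuda` (or `x.device.type`) with actual tensors and the actual extension modules.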
Ah OK, I see. In this case, the Dispatching tutorial might be useful, as it seems to target your use case:

> The dispatcher is an internal component of PyTorch which is responsible for figuring out what code should actually get run when you call a function like torch::add. This can be nontrivial, because PyTorch operations need to handle a lot of cross-cutting concerns that are “layered” on top of one another. Here is a sampling of some of the things it handles:
> - Switching between the CPU and CUDA implementations of an operator, depending on the devices of the input tensors. …
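The idea in the quote can be mimicked in a few lines of plain Python: a table maps an (operator name, device type) pair to a kernel, and every call is routed through a lookup. This is purely an illustration of the concept, not PyTorch's actual implementation.

```python
# Toy dispatch table: (operator name, device type) -> kernel function.
_KERNELS = {}

def register(op, device, fn):
    """Register `fn` as the kernel for `op` on `device`."""
    _KERNELS[(op, device)] = fn

def dispatch(op, device, *args):
    """Look up the kernel registered for (op, device) and run it."""
    return _KERNELS[(op, device)](*args)

# Two kernels for the same logical operator, one per device type.
register("lltm::forward", "cpu", lambda x: ("cpu kernel", x))
register("lltm::forward", "cuda", lambda x: ("cuda kernel", x))
```

The real dispatcher does this per registered operator and dispatch key, which is why a properly registered op picks the CPU or CUDA kernel automatically from the input tensors.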
This is exactly what I needed, thank you!