JIT Compiling Extensions

Hi everyone,

I’m trying to use the deformable convolution C++/CUDA extensions from the mmdetection repo without a setup.py, compiling them just in time with torch.utils.cpp_extension.load() instead, as suggested here. However, I’m having trouble passing the correct source paths to load().

My folder structure is as follows:

├── dcn
│   ├── deform_conv.py
│   ├── deform_pool.py
│   ├── __init__.py
│   └── src
│       ├── deform_conv_cuda.cpp
│       ├── deform_conv_cuda_kernel.cu
│       ├── deform_pool_cuda.cpp
│       └── deform_pool_cuda_kernel.cu
└── test_dcn.py

In test_dcn.py I import the deformable convolutions with from dcn import DeformConvPack and in the file deform_conv.py I inserted the following at the top:

# deform_conv.py

from torch.utils.cpp_extension import load
deform_conv_cuda = load(name='deform_conv_cuda', sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])

# Rest of the code
# ...

I expected this to work, but the compilation fails because the .cpp and .cu files cannot be found. If I instead specify the sources as 'dcn/src/deform_conv_cuda.cpp' and 'dcn/src/deform_conv_cuda_kernel.cu', it works.
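One way to catch this before the compiler even starts is to check the sources yourself. The sketch below uses a hypothetical helper, missing_sources (not part of PyTorch), that resolves each entry the way os.path.abspath does, i.e. relative to the current working directory unless an explicit base directory is given, and returns the entries that don't exist. It assumes load() resolves relative paths the same way, which is consistent with the behavior described in this thread.

```python
import os


def missing_sources(sources, base_dir=None):
    """Return the entries of `sources` that do not exist as files.

    Relative entries are joined onto `base_dir`, or onto the current working
    directory when `base_dir` is None (this is what os.path.abspath does).
    Hypothetical helper for debugging, not a PyTorch API.
    """
    missing = []
    for src in sources:
        if base_dir is not None and not os.path.isabs(src):
            path = os.path.join(base_dir, src)
        else:
            path = os.path.abspath(src)  # relative -> joined onto os.getcwd()
        if not os.path.isfile(path):
            missing.append(src)
    return missing


sources = ['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']
print("cwd:", os.getcwd())
print("missing:", missing_sources(sources))
```

Run from the folder containing dcn, the relative 'src/...' entries are reported as missing, while the 'dcn/src/...' variants are not.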

Could someone explain the logic behind this to me? Thank you very much =)

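Basically, the working directory of the process and the paths Python uses to resolve module imports are distinct things. Relative paths in sources are resolved against the current working directory (typically the directory you launched Python from), not against the directory containing the module that calls load(). Compilation runs in child processes that inherit this working directory, so it is the one used as the base, even though the module calling load() lives in an entirely different directory. Since 'dcn/src/...' works for you, you are evidently launching test_dcn.py from the folder that contains dcn.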
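To make the distinction concrete, here is a small self-contained sketch that builds a throwaway package (named dcn_demo, purely for illustration) in one temporary directory, makes it importable via sys.path, and imports it while the working directory is a different temporary directory. The package records where its own file lives, what os.getcwd() returns, and where os.path.abspath sends the relative path 'src/deform_conv_cuda.cpp'. It does not call load() itself; it only assumes load() treats relative paths like os.path.abspath does.

```python
import importlib
import os
import sys
import tempfile


def where_is_what(pkg_name="dcn_demo"):
    """Import a throwaway package from one directory while the cwd is another.

    Returns (directory of the module, cwd seen at import, resolved 'src/...').
    """
    project = tempfile.mkdtemp()    # plays the role of the project root
    elsewhere = tempfile.mkdtemp()  # an unrelated working directory
    pkg_dir = os.path.join(project, pkg_name)
    os.makedirs(pkg_dir)
    with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
        f.write(
            "import os\n"
            "MODULE_DIR = os.path.dirname(os.path.abspath(__file__))\n"
            "CWD_AT_IMPORT = os.getcwd()\n"
            "RESOLVED = os.path.abspath('src/deform_conv_cuda.cpp')\n"
        )
    old_cwd = os.getcwd()
    sys.path.insert(0, project)  # imports are found via sys.path, not the cwd
    try:
        os.chdir(elsewhere)
        importlib.invalidate_caches()
        mod = importlib.import_module(pkg_name)
        return mod.MODULE_DIR, mod.CWD_AT_IMPORT, mod.RESOLVED
    finally:
        os.chdir(old_cwd)
        sys.path.remove(project)


module_dir, cwd, resolved = where_is_what()
print("module lives in:   ", module_dir)
print("cwd during import: ", cwd)
print("'src/...' resolves:", resolved)  # lands under the cwd, not module_dir
```

The import succeeds because sys.path points at the package, but the relative source path still ends up under the unrelated working directory.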

Thanks for the answer! In the meantime I came to the same conclusion and used the following snippet as a workaround. That way I can import the nn.Module defined in deform_conv.py from anywhere without having to worry about getting the path right. I’m just a bit surprised that the torch.utils.cpp_extension.load() function doesn’t take care of that.

# deform_conv.py

import os
from torch.utils.cpp_extension import load
parent_dir = os.path.dirname(os.path.abspath(__file__))
sources = ['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']  # Paths of sources relative to this file
abs_sources = [os.path.join(parent_dir, source) for source in sources]  # Absolute paths of sources
deform_conv_cuda = load(name='deform_conv_cuda', sources=abs_sources)  # JIT compilation of extensions
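The same workaround can also be written with pathlib. The sketch below wraps it in a hypothetical helper, sources_next_to (not a PyTorch function), that takes the calling module's __file__ plus source paths written relative to that module's folder and returns absolute path strings suitable for the sources argument of load(). The load() call is shown commented out because it needs PyTorch and a CUDA toolchain to actually run.

```python
from pathlib import Path


def sources_next_to(module_file, *relative_sources):
    """Return absolute paths for sources given relative to module_file's folder.

    Hypothetical helper, not part of torch.utils.cpp_extension.
    """
    base = Path(module_file).resolve().parent  # folder containing the module
    return [str(base / rel) for rel in relative_sources]


# Intended use at the top of dcn/deform_conv.py (needs PyTorch + CUDA toolchain):
#
# from torch.utils.cpp_extension import load
# deform_conv_cuda = load(
#     name='deform_conv_cuda',
#     sources=sources_next_to(__file__,
#                             'src/deform_conv_cuda.cpp',
#                             'src/deform_conv_cuda_kernel.cu'),
# )
```

One difference from the snippet above: Path.resolve() also follows symlinks, whereas os.path.abspath() only normalizes the path. If deform_conv.py is reached through a symlink, the helper looks for src/ next to the real file rather than next to the link.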