Working on a custom `torch` CUDA / C++ extension that loads a `cubin` image using the CUDA driver API (`cuLaunchKernel`).
Are there any precautions needed when calling CUDA driver functions from inside a `torch` extension? I'm getting a segfault when calling the kernel, which has the following signature:
f(CUstream stream, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t M, int32_t N, int32_t K, int32_t stride_cm, int32_t stride_cn, int32_t stride_am, int32_t stride_ak, int32_t stride_bk, int32_t stride_bn)
I’m casting `A`, `B`, and `C` from `at::Tensor` to `CUdeviceptr` using `reinterpret_cast<CUdeviceptr>(A.data_ptr())` and using `at::cuda::getCurrentStream()` for `stream`.
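
To make the setup concrete, here is a minimal sketch of the pointer cast and of how the arguments would be packed for `cuLaunchKernel`. It is only a sketch under stated assumptions: `DevicePtr` stands in for `CUdeviceptr` so that it builds without `cuda.h`, and `GemmArgs` and `to_device_ptr` are hypothetical names, not part of torch or the driver API. One thing worth checking against it is that `kernelParams` must be an array of pointers to the argument values, not the values themselves, and that the stream is passed separately rather than as a kernel parameter.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Stand-in for the driver API's CUdeviceptr so the sketch builds without
// cuda.h (assumption: an unsigned 64-bit integer, as on 64-bit builds).
using DevicePtr = std::uint64_t;

// Hypothetical helper mirroring reinterpret_cast<CUdeviceptr>(A.data_ptr()).
inline DevicePtr to_device_ptr(void* p) {
    // Guard against a null pointer before handing it to the driver.
    assert(p != nullptr);
    return reinterpret_cast<DevicePtr>(p);
}

// Hypothetical holder that owns copies of the kernel arguments, in the same
// order as the signature above (the stream is not a kernel parameter).
struct GemmArgs {
    DevicePtr C, A, B;
    std::int32_t M, N, K;
    std::int32_t stride_cm, stride_cn;
    std::int32_t stride_am, stride_ak;
    std::int32_t stride_bk, stride_bn;

    // Each entry is the ADDRESS of an argument, not the argument's value.
    // The pointers stay valid only while this GemmArgs object is alive.
    std::array<void*, 12> params() {
        return {{&C, &A, &B, &M, &N, &K,
                 &stride_cm, &stride_cn, &stride_am, &stride_ak,
                 &stride_bk, &stride_bn}};
    }
};
```

In the real extension, `args.params().data()` would be what goes into the `kernelParams` argument of `cuLaunchKernel`, with the stream from `at::cuda::getCurrentStream()` passed as the separate `hStream` argument.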