Custom CUDA kernel for the backward pass using Numba?

I am looking to write a custom kernel for the backward pass in my model to speed up training. The tutorial on the PyTorch website goes into detail about writing the extension in C++ and CUDA.

Given that Numba makes writing CUDA kernels pretty straightforward, would there be any performance hurdle in using it instead? It seems like data on the GPU can be passed between Numba and PyTorch without much performance cost.

Has anyone tried this? Does anyone have some words of warning against doing so?


I haven’t compared any implementations and am not aware of any comparisons, but I would recommend just going for it. I’m not aware of any specific limitations, so I would stick with whatever works best for you.