CUDA Stream for PyTorch C++/CUDA Custom Extension

Hello,

I was studying the “Custom C++ and CUDA Extensions” tutorial from the official PyTorch tutorials. I am not sure how PyTorch is implemented, but I would assume it uses CUDA streams for asynchronous execution. However, it seems that we don’t have access to the CUDA stream when writing a custom extension, which would mean that the extension always executes synchronously. Can someone give me some pointers on this? Thank you very much.

Best,

Lei

Can you paste some example code? I think you should have access to the CUDA stream from C++. If you can share some code and your observations, I might be able to help.

@glaringlee
Thank you, Xinyun. I am looking at this particular line of code:
https://github.com/pytorch/extension-cpp/blob/master/cuda/lltm_cuda_kernel.cu#L119.
I don’t think a CUDA stream is passed there?
Best,
Lei

@leimao
Yes, it doesn’t. It uses the default stream in this case. But you can use your own stream by tweaking this line, right?
https://github.com/pytorch/extension-cpp/blob/master/cuda/lltm_cuda_kernel.cu#L120
You can pass a CUDA stream inside the angle brackets as the 4th launch parameter (the 3rd parameter, the dynamic shared memory size, is 0).
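For reference, here is a minimal standalone sketch of that launch syntax in plain CUDA runtime code, with no PyTorch involved. The names `scale_kernel` and `scale_on_own_stream` are hypothetical; the point is only the `<<<grid, block, shared_bytes, stream>>>` form. Note that the stream here is created by the function itself, so PyTorch has no knowledge of it.

```cuda
#include <cuda_runtime.h>
#include <vector>

// Hypothetical element-wise kernel, used only to illustrate the launch syntax.
__global__ void scale_kernel(const float* in, float* out, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * factor;
}

// Hypothetical helper: runs scale_kernel on a stream it creates itself and
// returns true if every output element equals factor (inputs are all 1.0f).
// Requires n > 0. PyTorch knows nothing about this stream.
bool scale_on_own_stream(int n, float factor) {
  std::vector<float> h_in(n, 1.0f), h_out(n, 0.0f);
  const size_t bytes = n * sizeof(float);
  float* d_in = nullptr;
  float* d_out = nullptr;
  cudaStream_t stream = nullptr;
  if (cudaStreamCreate(&stream) != cudaSuccess) return false;
  if (cudaMalloc(&d_in, bytes) != cudaSuccess ||
      cudaMalloc(&d_out, bytes) != cudaSuccess) {
    cudaFree(d_in);
    cudaFree(d_out);
    cudaStreamDestroy(stream);
    return false;
  }
  // Work queued on the same stream executes in issue order.
  cudaMemcpyAsync(d_in, h_in.data(), bytes, cudaMemcpyHostToDevice, stream);
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Execution configuration: <<<grid, block, dynamic shared memory bytes, stream>>>
  scale_kernel<<<blocks, threads, 0, stream>>>(d_in, d_out, factor, n);
  cudaMemcpyAsync(h_out.data(), d_out, bytes, cudaMemcpyDeviceToHost, stream);
  // Block the host until everything queued on `stream` has finished.
  bool ok = cudaStreamSynchronize(stream) == cudaSuccess &&
            cudaGetLastError() == cudaSuccess;
  for (int i = 0; ok && i < n; ++i) ok = (h_out[i] == factor);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
  return ok;
}
```

Calling `scale_on_own_stream(1000, 2.0f)` on a machine with a GPU should return `true`.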

Thank you very much for the quick response @glaringlee. I can create a CUDA stream and use it here in the extension. But that doesn’t really make sense, because the CUDA stream used by the extension would not be the same as the one used by the neural network backbone. Therefore, the extension’s execution would not be truly asynchronous with the rest of the model. Ideally, it is PyTorch’s responsibility to expose its CUDA stream to the user.
I think TensorFlow exposes the CUDA stream to the user:


Please press Ctrl+F and search for “stream”.

That is to say, PyTorch should own the CUDA stream, not the extension.

@leimao
No, I mean you can use the API exposed by libtorch to get the CUDA stream.
I recently wrote this note: https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html. It is not finished yet; someone said it is not working well but never followed up. Can you try it and let me know whether it works? Also, please be careful with setCurrentStream if you plan to use it; it is very tricky.
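A minimal libtorch sketch of this kind of API, assuming a CUDA build of libtorch: `c10::cuda::getStreamFromPool()` hands out a stream, `c10::cuda::CUDAStreamGuard` makes it current for one scope and restores the previous stream on exit, and `at::cuda::getCurrentCUDAStream()` reports the current one. The guard is used here instead of `c10::cuda::setCurrentCUDAStream`, which leaves the new stream current until something changes it back; that persistence is one plausible reason the call is described as tricky. The function name `guard_switches_current_stream` is hypothetical.

```cuda
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical check: returns true if the guard made the pool stream current
// inside its scope and restored the previous stream afterwards.
bool guard_switches_current_stream() {
  auto before = at::cuda::getCurrentCUDAStream();
  c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
  torch::Tensor t = torch::ones({2, 2}, torch::device(torch::kCUDA));
  bool inside_ok = false;
  {
    // RAII: `side` is the current stream only until the end of this scope.
    c10::cuda::CUDAStreamGuard guard(side);
    inside_ok = (at::cuda::getCurrentCUDAStream() == side);
    auto s = t.sum();  // PyTorch queues this reduction on `side`
  }
  bool restored = (at::cuda::getCurrentCUDAStream() == before);
  side.synchronize();  // wait for the work queued on `side`
  return inside_ok && restored;
}
```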

Thank you very much @glaringlee. I did not know we had this API. I can certainly try it, but I am not sure how to test whether it is actually running asynchronously on the correct CUDA stream.

@leimao
I think what you can do is, for example, get 2 streams from the stream pool, pass each one to a separate kernel launch, and in the host code that launches each kernel (device code cannot query its stream), check the current stream. Then you know whether your kernels are running on both streams.
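That check could be sketched roughly as follows, assuming libtorch with CUDA. `spin_kernel`, `launch_on_current_stream` and `two_pool_streams_check` are hypothetical names: the launcher always uses PyTorch’s current stream and reports the raw handle it used, and the test makes a different pool stream current for each launch. Whether the two kernels actually overlap is best confirmed on a profiler timeline (for example Nsight Systems); this code only checks which stream each launch went to.

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

// Hypothetical kernel that busy-waits so overlap is easy to spot in a profiler.
__global__ void spin_kernel(long long cycles) {
  long long start = clock64();
  while (clock64() - start < cycles) {
  }
}

// Hypothetical launcher: uses whatever PyTorch considers the current stream
// and returns the raw handle it launched on.
cudaStream_t launch_on_current_stream(long long cycles) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  spin_kernel<<<1, 1, 0, stream>>>(cycles);
  return stream;
}

// Launches once under each of two pool streams; returns true if each launch
// landed on the stream that was current at the time.
bool two_pool_streams_check() {
  c10::cuda::CUDAStream s1 = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStream s2 = c10::cuda::getStreamFromPool();
  cudaStream_t used1 = nullptr;
  cudaStream_t used2 = nullptr;
  {
    c10::cuda::CUDAStreamGuard guard(s1);
    used1 = launch_on_current_stream(1000000);
  }
  {
    c10::cuda::CUDAStreamGuard guard(s2);
    used2 = launch_on_current_stream(1000000);
  }
  s1.synchronize();
  s2.synchronize();
  return used1 == s1.stream() && used2 == s2.stream() && used1 != used2;
}
```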


Hi @leimao, I have recently been working on similar problems and eventually ended up at this example: flownet2.
It seems that passing at::cuda::getCurrentCUDAStream() as the 4th launch parameter of your kernel respects the stream set in your network backbone (i.e., in PyTorch code wrapped in `with torch.cuda.stream(s):`).

I tested it myself with the NVIDIA profiler, and it did work.
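A minimal sketch of that pattern as an extension op, assuming a float32 CUDA input. `scale_kernel`, `scale_cuda` and the Python-side name `scale` are hypothetical and are not taken from flownet2; only the use of `at::cuda::getCurrentCUDAStream()` in the launch reflects the approach described above. The device guard makes the current device match the input’s, so the stream returned is the one for that device.

```cuda
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

// Hypothetical element-wise kernel (illustration only).
__global__ void scale_kernel(const float* in, float* out, float factor, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) out[i] = in[i] * factor;
}

torch::Tensor scale_cuda(torch::Tensor x, double factor) {
  TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat, "x must be float32");
  // Make x's device current so the stream below belongs to that device.
  const c10::cuda::CUDAGuard device_guard(x.device());
  x = x.contiguous();
  auto out = torch::empty_like(x);
  const int64_t n = x.numel();
  if (n == 0) return out;
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  // PyTorch's current stream: follows `with torch.cuda.stream(s):` in Python
  // instead of always using the legacy default stream 0.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<blocks, threads, 0, stream>>>(
      x.data_ptr<float>(), out.data_ptr<float>(), static_cast<float>(factor), n);
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "scale_kernel launch failed: ",
              cudaGetErrorString(err));
  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale", &scale_cuda, "Scale a float32 CUDA tensor on the current stream");
}
```

On the Python side, after building this with `torch.utils.cpp_extension.load` (say as a module named `ext`, a hypothetical name), calling `ext.scale(x, 2.0)` inside `with torch.cuda.stream(s):` should queue the kernel on `s`, which is the behavior a profiler timeline would confirm.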


Thank you very much. I will take a look.

PyTorch provides a CUDA stream API, getCurrentCUDAStream(); please take a look.
