CUDA Streams in the C++ Library?

I want to do something like this in C++:

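```python
import torch

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
# Initialise cuda tensors here. E.g.:
A = torch.rand(1000, 1000, device='cuda')
B = torch.rand(1000, 1000, device='cuda')
# Wait for the above tensors to initialise.
torch.cuda.synchronize()
with torch.cuda.stream(s1):
    C = torch.mm(A, A)
with torch.cuda.stream(s2):
    D = torch.mm(B, B)
# Wait for C and D to be computed.
torch.cuda.synchronize()
# Do stuff with C and D.
```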

What is the right way to do this in C++? Say I already have a CUDA stream allocated and I want to run `torch::mm` on it: how do I do that?

`at::cuda::CUDAStreamGuard guard(stream)` is what you are looking for.


So just by declaring `at::cuda::CUDAStreamGuard guard(stream)`, subsequent torch operations will run on that stream? And if I put my code inside a `{}` block, does it go back to the default stream once execution leaves the block?

I tried running my code with and without the CUDAStreamGuard lines commented out, and the runtime is the same:

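```cpp
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getStreamFromPool
#include <c10/cuda/CUDAGuard.h>     // at::cuda::CUDAStreamGuard

// Three side streams from PyTorch's stream pool (the vector must be filled before indexing).
std::vector<at::cuda::CUDAStream> streams;
for (int i = 0; i < 3; ++i) streams.push_back(at::cuda::getStreamFromPool());

torch::mm_out(X, A, B);  // current stream
{ at::cuda::CUDAStreamGuard g0(streams[0]); torch::mm_out(Y, A, C); }  // streams[0]
{ at::cuda::CUDAStreamGuard g1(streams[1]); torch::mm_out(Z, A, D); }  // streams[1]
{ at::cuda::CUDAStreamGuard g2(streams[2]); torch::mm_out(W, A, E); }  // streams[2]
```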

What am I doing wrong?

CUDA streams don't magically speed up your code, and they won't speed up most use cases. They are mostly useful for overlapping copies with compute. They may also help if your kernels do not use all the SMs and are not all completely memory-bound.

Most likely this is not the case for your code, but it's hard to guess what the limiting factor is because you don't show how X, A, B, etc. are defined or what measuretime() does.

Okay, I have yet to profile it with the NVIDIA profiler, but I eventually will. I assumed there would be a speedup since the multiplications don't depend on each other and the matrices are reasonably large. Removing any one multiplication cuts 0.1 ms off, and overall it takes 0.4 ms whether or not I use the stream guard blocks.

I just wanted to make sure that I am using the API correctly.

Could you give me some advice? Thanks very much.