CUDA streams not running in parallel?

I’m trying to use CUDA streams to run n operations in parallel on a single GPU. However, based on a simple test script I wrote, running 100,000 matrix-vector multiplications in parallel streams actually takes longer than running them in serial. I’m hoping that someone can explain to me why this is the case. Does it have to do with the fact that the default stream synchronizes with all other streams?

test_streams.py:

import argparse
import datetime

import torch

def map_reduce_serial(args, map_func, reduce_func):
    results = []
    for arg in args:
        results.append(map_func(arg))
    return reduce_func(results)

def map_reduce_parallel(device, args, map_func, reduce_func):
    results = []
    main_stream = torch.cuda.current_stream(device)
    for arg in args:
        stream = torch.cuda.Stream(device)
        stream.wait_stream(main_stream)
        with torch.cuda.stream(stream):
            results.append(map_func(arg))
        main_stream.wait_stream(stream)
    return reduce_func(results)

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--parallel', action='store_true', default=False)
    parser.add_argument('-n', type=int, default=100000)
    args = parser.parse_args()

    device = torch.device('cuda')

    W = torch.rand((20, 20), device=device)

    def map_func(i):
        x = torch.rand((20,), device=device)
        return W * x  # broadcasted elementwise product, shape (20, 20)

    def reduce_func(results):
        return torch.stack(results).sum().item()

    if args.parallel:
        map_reduce_func = lambda *args: map_reduce_parallel(device, *args)
    else:
        map_reduce_func = map_reduce_serial

    torch.cuda.reset_max_memory_allocated(device)
    start_time = datetime.datetime.now()
    result = map_reduce_func(range(args.n), map_func, reduce_func)
    torch.cuda.synchronize(device)
    duration = datetime.datetime.now() - start_time
    memory = torch.cuda.max_memory_allocated(device)
    print('result:', result)
    print('time:', duration)
    print('memory:', memory)

if __name__ == '__main__':
    main()

Running in serial:

$ python test_streams.py
result: 9972008.0
time: 0:00:02.027303
memory: 364826624

Running in parallel:

$ python test_streams.py --parallel
result: 9594629.0
time: 0:00:05.606621
memory: 364826624

If you wait for the main stream at every step, you don’t really run in parallel, right?

.wait_stream() is actually not synchronous on the host; it just means that future operations submitted to the waiting stream will not execute until the operations currently queued on the other stream have completed.

https://pytorch.org/docs/stable/cuda.html#torch.cuda.Stream.wait_stream

It’s just meant to prevent the parallel streams from running until W has been initialized.
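
For reference, here is a minimal sketch of those semantics (hypothetical tensors, not part of the benchmark): the wait_stream() call returns immediately on the host and only inserts an ordering dependency between the two streams.

import torch

device = torch.device('cuda')
main = torch.cuda.current_stream(device)
side = torch.cuda.Stream(device)

W = torch.rand((20, 20), device=device)      # queued on the main stream

side.wait_stream(main)                       # returns immediately (asynchronous)
with torch.cuda.stream(side):
    # Work queued here after the wait is guaranteed to see the finished W.
    y = W * torch.rand((20,), device=device)

main.wait_stream(side)                       # orders *later* main-stream work after side
torch.cuda.synchronize(device)               # this is the call that actually blocks the host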

If I take out the wait_stream() calls, it still takes longer anyway:

result: 9894706.0
time: 0:00:04.524583
memory: 364826624

After chatting with some CUDA experts, the answer is:
For small 20x20 operations, the overhead of the multi-stream approach will be larger than running everything on the same stream. There’s a narrow zone where streams actually help, but 20x20 is not it; workloads like that should be handled by batched kernels.
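
For what it’s worth, here is a sketch of the batched alternative, assuming the per-item work really is the broadcasted W * x from the script above (tensor names are hypothetical): all n products are expressed as one kernel launch instead of n tiny ones.

import torch

device = torch.device('cuda')
n = 100000

W = torch.rand((20, 20), device=device)
X = torch.rand((n, 1, 20), device=device)   # all n input vectors generated at once

out = W * X                                 # single broadcasted kernel, shape (n, 20, 20)
result = out.sum().item()
print(result)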

Thanks for checking!

I did suspect that might be the case, so I tried it with bigger operations, like 11,000 x 11,000. However, the best I can get is an extremely marginal benefit from parallelization before I run out of GPU memory.

Do you know what the nature of the overhead is? It seems a shame that CUDA streams can’t even come close to performing as well as a single CUDA kernel. I’m dealing with a situation where it would save me a lot of memory if I could run multiple kernels in parallel rather than using just one.

test_streams.py

import argparse
import datetime

import torch

def map_reduce_serial(args, map_func, reduce_func):
    results = []
    for arg in args:
        results.append(map_func(arg))
    return reduce_func(results)

def map_reduce_parallel(device, args, map_func, reduce_func):
    results = []
    main_stream = torch.cuda.current_stream(device)
    for arg in args:
        stream = torch.cuda.Stream(device)
        #stream.wait_stream(main_stream)
        with torch.cuda.stream(stream):
            results.append(map_func(arg))
        #main_stream.wait_stream(stream)
    return reduce_func(results)

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--parallel', action='store_true', default=False)
    parser.add_argument('-n', type=int, default=100000)
    parser.add_argument('--size', type=int, default=20)
    args = parser.parse_args()

    device = torch.device('cuda')

    main_stream = torch.cuda.Stream(device)
    with torch.cuda.stream(main_stream):

        W = torch.rand((args.size, args.size), device=device)

        def map_func(i):
            x = torch.rand((args.size,), device=device)
            return W * x

        def reduce_func(results):
            return torch.stack(results).sum().item()

        if args.parallel:
            map_reduce_func = lambda *args: map_reduce_parallel(device, *args)
        else:
            map_reduce_func = map_reduce_serial

        torch.cuda.reset_max_memory_allocated(device)
        start_time = datetime.datetime.now()
        result = map_reduce_func(range(args.n), map_func, reduce_func)
        torch.cuda.synchronize(device)
        duration = datetime.datetime.now() - start_time
        memory = torch.cuda.max_memory_allocated(device)
        print('result:', result)
        print('time:', duration)
        print('memory:', memory)

if __name__ == '__main__':
    main()

$ python test_streams.py -n 10 --size 11000
result: 301739680.0
time: 0:00:00.066573
memory: 10169239040
$ python test_streams.py -n 10 --size 11000 --parallel
result: 272663200.0
time: 0:00:00.065152
memory: 10169239040

I’m not sure about the overhead, but I would guess that the synchronization is expensive, and that there isn’t much opportunity to actually run work concurrently on the GPU.

I’m not sure I see how streams could help you reduce memory usage, though?

My model involves a triangular matrix, where all entries below the diagonal are considered to be zero. If I use a single kernel, I need to store the matrix as a dense, square matrix with zeros, because PyTorch kernels do not operate on “jagged” tensors; if I can use multiple kernels, then I can store the matrix as columns of varying length, without needing to take up memory for the zeros.
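
For concreteness, here is a hedged sketch of that jagged storage (the sizes and the per-column reduction are made up; the real model’s function isn’t shown here):

import torch

device = torch.device('cuda')
n = 1000

# Build a dense upper-triangular matrix only to illustrate the packing.
dense = torch.triu(torch.rand((n, n), device=device))

# Column j of an upper-triangular matrix has j + 1 nonzero entries, so the
# zeros below the diagonal never need to be materialized.
columns = [dense[: j + 1, j].clone() for j in range(n)]

# Each column is then processed by its own kernel launch (optionally on its
# own stream), at the cost of n launches instead of one.
col_sums = torch.stack([c.sum() for c in columns])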

I could get into more detail, but that would be way outside the scope of my original question.

Here’s an interesting plot that illustrates how poorly the CUDA streams are performing within my application. My model is designed such that, if the CUDA streams were fully parallelized, the time complexity would drop from cubic to quadratic. However, the orange line (the CUDA streams implementation) still has the same time complexity as the serial version. On the other hand, a single-kernel version that I implemented (green line) has a time complexity much closer to quadratic.

[plot: running time of the serial, CUDA-streams, and single-kernel implementations]

Sadly, the memory usage goes up quite a bit as a result:

[plot: memory usage of the same implementations]

Can you share more details about what the function in your kernel is doing?

Unfortunately not, as it is an active research project that is yet to be published.

If the overhead of launching a stream is so high, would it help if the CUDA streams were launched from different threads?

I have to admit I don’t know. cc @ngimel

Is there any chance you could write your function to be applied to a 1D tensor that contains the upper-triangular part of the matrix? That would reduce the memory, and you could use a single kernel.

Launching CUDA streams from different threads would help a little (remember, you cannot fruitfully do it from Python because of the GIL, so you’d have to do it in a C++ extension), but the effect will still be very limited, because some of the work necessary to launch the kernels is serialized by the driver.
The general ballpark is still that a single kernel is much more efficient than launching many kernels onto separate streams (that’s how, e.g., all batched matrix multiplies are implemented; launching individual matrix multiplies onto different streams does not come even close to the performance of a single kernel). Once the work performed by a single stream is large enough to fully occupy the GPU, there’s also no benefit from streams, because they are essentially serialized: the next one will start only after the first one completes and frees the GPU.
I’d vote for Alban’s suggestion: if you can write your function to be applied to a 1D tensor, that’s the way to go.
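
In case it helps, a minimal sketch of that 1D-tensor approach, assuming the function is element-wise (torch.sigmoid stands in for the real function; the sizes are hypothetical):

import torch

device = torch.device('cuda')
n = 1000
dense = torch.rand((n, n), device=device)

# Pack only the upper-triangular entries into a flat tensor.
rows, cols = torch.triu_indices(n, n, device=device)
packed = dense[rows, cols]            # shape (n * (n + 1) // 2,), no stored zeros

# One kernel over the packed values instead of one kernel per column/stream.
out = torch.sigmoid(packed)

# Scatter back into dense form only if a square view is actually needed.
result = torch.zeros((n, n), device=device)
result[rows, cols] = out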

Thank you for your reply. My function is not quite simple enough to be converted to a 1D tensor operation, but I will consider writing my own CUDA extension.

Greetings,

I suggest checking out Numba, which has CUDA support; it spares you the hassle of having to deal with C. It has been very useful in my research.
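
For anyone curious, here is a minimal sketch of what a Numba CUDA kernel looks like (a made-up element-wise kernel, unrelated to the model discussed above):

import numpy as np
from numba import cuda

@cuda.jit
def scale_kernel(x, out, factor):
    # Each thread handles one element of the 1D array.
    i = cuda.grid(1)
    if i < x.shape[0]:
        out[i] = x[i] * factor

x = np.random.rand(1_000_000).astype(np.float32)
d_x = cuda.to_device(x)
d_out = cuda.device_array_like(d_x)

threads = 256
blocks = (x.shape[0] + threads - 1) // threads
scale_kernel[blocks, threads](d_x, d_out, 2.0)

print(d_out.copy_to_host()[:5])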

Thanks for the tip, looks interesting!