CUDA streams not running in parallel?

Thanks for checking!

I did suspect that might be the case, so I tried it with much bigger operations, e.g. 11,000 × 11,000 matrices. Even then, the best I can get is an extremely marginal benefit from parallelization before I run out of GPU memory.

Do you know what the nature of the overhead is? It seems a shame that CUDA streams can’t come anywhere close to matching the performance of a single CUDA kernel. I’m in a situation where running several smaller kernels in parallel, rather than one large batched kernel, would save me a lot of memory.
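
For reference, below is the synchronization pattern I understand to be required when splitting work across streams (the tensor names and sizes are just placeholders, not from my real code). In the test script further down, the corresponding wait_stream calls are left commented out:

import torch

device = torch.device('cuda')
main_stream = torch.cuda.current_stream(device)
side_stream = torch.cuda.Stream(device)

a = torch.rand((4096, 4096), device=device)
b = torch.rand((4096, 4096), device=device)

# The side stream has to wait for work already queued on the main
# stream (here, the kernels that initialize a and b).
side_stream.wait_stream(main_stream)
with torch.cuda.stream(side_stream):
    c = a @ a  # queued on side_stream; may overlap the matmul below
d = b @ b  # queued on main_stream

# The main stream has to wait for the side stream before consuming c,
# and c should be recorded on the main stream so the caching allocator
# doesn't recycle its memory while side_stream might still be writing it.
main_stream.wait_stream(side_stream)
c.record_stream(main_stream)
total = (c + d).sum().item()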

test_streams.py

import argparse
import datetime

import torch

def map_reduce_serial(args, map_func, reduce_func):
    # Baseline: every map kernel is queued on the current stream, in order.
    results = []
    for arg in args:
        results.append(map_func(arg))
    return reduce_func(results)

def map_reduce_parallel(device, args, map_func, reduce_func):
    # Queue each map kernel on its own stream so the GPU is free to
    # overlap them. The wait_stream calls below are the cross-stream
    # synchronization that would normally be required; they are left
    # commented out here to measure the raw overlap behavior.
    results = []
    main_stream = torch.cuda.current_stream(device)
    for arg in args:
        stream = torch.cuda.Stream(device)
        #stream.wait_stream(main_stream)  # would order this stream after W's init
        with torch.cuda.stream(stream):
            results.append(map_func(arg))
        #main_stream.wait_stream(stream)  # would order the reduce after this stream
    return reduce_func(results)

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--parallel', action='store_true', default=False)
    parser.add_argument('-n', type=int, default=100000)
    parser.add_argument('--size', type=int, default=20)
    args = parser.parse_args()

    device = torch.device('cuda')

    # Run everything inside an explicit stream so that
    # torch.cuda.current_stream(device) in map_reduce_parallel refers to it.
    main_stream = torch.cuda.Stream(device)
    with torch.cuda.stream(main_stream):

        # Shared weight matrix, initialized once on the main stream.
        W = torch.rand((args.size, args.size), device=device)

        def map_func(i):
            # Draw a fresh random vector and broadcast-multiply it
            # against W, producing a full size x size result.
            x = torch.rand((args.size,), device=device)
            return W * x

        def reduce_func(results):
            # Stack all the intermediate matrices and collapse to a scalar.
            return torch.stack(results).sum().item()

        if args.parallel:
            map_reduce_func = lambda *args: map_reduce_parallel(device, *args)
        else:
            map_reduce_func = map_reduce_serial

        torch.cuda.reset_max_memory_allocated(device)
        # Drain setup work (e.g. the initialization of W) so it doesn't
        # count toward the measured duration.
        torch.cuda.synchronize(device)
        start_time = datetime.datetime.now()
        result = map_reduce_func(range(args.n), map_func, reduce_func)
        torch.cuda.synchronize(device)
        duration = datetime.datetime.now() - start_time
        memory = torch.cuda.max_memory_allocated(device)
        print('result:', result)
        print('time:', duration)
        print('memory:', memory)

if __name__ == '__main__':
    main()
$ python test_streams.py -n 10 --size 11000
result: 301739680.0
time: 0:00:00.066573
memory: 10169239040
$ python test_streams.py -n 10 --size 11000 --parallel
result: 272663200.0
time: 0:00:00.065152
memory: 10169239040
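
Incidentally, the memory number looks about right to me if I count buffers: each 11,000 × 11,000 float32 result is 11000 * 11000 * 4 ≈ 484 MB, so ten map results plus the stacked copy plus W come to roughly 21 * 484 MB ≈ 10.2 GB, which is close to the max_memory_allocated reported in both runs.

What I can’t tell from wall-clock time alone is whether the kernels overlap at all. One way to check would be to export a profiler trace and eyeball the timeline; a minimal sketch of what I have in mind (it assumes a PyTorch version with torch.profiler, i.e. 1.8.1 or later, and trace.json is just a placeholder path):

import torch
from torch.profiler import profile, ProfilerActivity

device = torch.device('cuda')
W = torch.rand((11000, 11000), device=device)

streams = [torch.cuda.Stream(device) for _ in range(4)]
results = []

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for stream in streams:
        # Order each side stream after the initialization of W.
        stream.wait_stream(torch.cuda.current_stream(device))
        with torch.cuda.stream(stream):
            x = torch.rand((11000,), device=device)
            results.append(W * x)
    for stream in streams:
        # Order the rest of the program after all side streams.
        torch.cuda.current_stream(device).wait_stream(stream)
    torch.cuda.synchronize(device)

# Open trace.json in chrome://tracing or Perfetto and check whether the
# multiply kernels on different streams actually execute concurrently.
prof.export_chrome_trace('trace.json')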