Thanks for checking!
I did suspect that might be the case, so I tried it with bigger operations, like 11,000 x 11,000. However, the best I can get is an extremely marginal benefit from parallelization before I run out of GPU memory.
Do you know what the nature of the overhead is? It seems a shame that CUDA streams can’t even come close to performing as well as a single CUDA kernel. I’m dealing with a situation where it would save me a lot of memory if I could run multiple kernels in parallel rather than using just one.
test_streams.py
import argparse
import datetime
import torch
def map_reduce_serial(args, map_func, reduce_func):
    """Apply ``map_func`` to every element of ``args`` in order, then
    collapse the mapped values with ``reduce_func``.

    All work runs sequentially on whatever stream/device is current.
    """
    return reduce_func([map_func(arg) for arg in args])
def map_reduce_parallel(device, args, map_func, reduce_func):
    """Apply ``map_func`` to every element of ``args``, each on its own
    CUDA stream, then collapse the results with ``reduce_func`` on the
    current (main) stream of ``device``.

    Without explicit cross-stream synchronization the side streams race
    with the main stream: a side stream may read tensors whose
    initialization is still queued on the main stream, and ``reduce_func``
    may consume results whose kernels have not finished.  The
    ``wait_stream`` calls below (the lines that were commented out in the
    original) establish both orderings.
    """
    results = []
    streams = []
    main_stream = torch.cuda.current_stream(device)
    for arg in args:
        stream = torch.cuda.Stream(device)
        # Side stream must not start until work already queued on the
        # main stream (e.g. creation of tensors the map reads) is ordered
        # before it.
        stream.wait_stream(main_stream)
        with torch.cuda.stream(stream):
            results.append(map_func(arg))
        streams.append(stream)
    # Main stream must wait for every side stream before the reduction
    # touches the mapped results.
    for stream in streams:
        main_stream.wait_stream(stream)
    # Inform the caching allocator that each result is also used on the
    # main stream, so its memory is not reclaimed/reused while a side
    # stream could still be writing it.
    for result in results:
        if isinstance(result, torch.Tensor):
            result.record_stream(main_stream)
    return reduce_func(results)
def main():
    """Benchmark serial vs. multi-stream map/reduce on a CUDA device.

    Flags:
        --parallel  use one CUDA stream per map call instead of the
                    current stream for everything
        -n          number of map calls
        --size      side length of the square weight matrix
    Prints the reduced result, wall-clock duration, and peak allocated
    GPU memory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--parallel', action='store_true', default=False)
    parser.add_argument('-n', type=int, default=100000)
    parser.add_argument('--size', type=int, default=20)
    args = parser.parse_args()
    device = torch.device('cuda')
    main_stream = torch.cuda.Stream(device)
    with torch.cuda.stream(main_stream):
        W = torch.rand((args.size, args.size), device=device)

    def map_func(i):
        x = torch.rand((args.size,), device=device)
        # Broadcasts the (size,) vector across the rows of the
        # (size, size) matrix, producing a (size, size) result.
        return W * x

    def reduce_func(results):
        return torch.stack(results).sum().item()

    if args.parallel:
        map_reduce_func = lambda *mr_args: map_reduce_parallel(device, *mr_args)
    else:
        map_reduce_func = map_reduce_serial
    # reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
    # is its documented replacement.
    torch.cuda.reset_peak_memory_stats(device)
    # Drain any pending setup work (e.g. the torch.rand for W) so it is
    # not counted in the measured duration.
    torch.cuda.synchronize(device)
    start_time = datetime.datetime.now()
    result = map_reduce_func(range(args.n), map_func, reduce_func)
    torch.cuda.synchronize(device)
    duration = datetime.datetime.now() - start_time
    memory = torch.cuda.max_memory_allocated(device)
    print('result:', result)
    print('time:', duration)
    print('memory:', memory)


if __name__ == '__main__':
    main()
$ python test_streams.py -n 10 --size 11000
result: 301739680.0
time: 0:00:00.066573
memory: 10169239040
$ python test_streams.py -n 10 --size 11000 --parallel
result: 272663200.0
time: 0:00:00.065152
memory: 10169239040