Hi, I encounter a silent failure when I try to capture CUDA graphs on multiple GPUs using the with torch.cuda.graph(g): context; it turns out that with torch.cuda.graph(g0, stream=s): works instead.
The problem is that only one of the two CUDA graphs successfully captures the wrapped functions. The other graph's replay() seems to do nothing. By chance, I found that the problem can be fixed by explicitly creating a new stream and using with torch.cuda.graph(g0, stream=s): (see the sketch at the end of this post). Could someone from the community explain why, and what the expected way to use CUDA graphs on multiple GPUs is?
More details:
What I intended to do is basically use CUDA graphs to accelerate the in-place add of two tensor lists on two different GPUs separately. The following code (mostly adapted from torch.cuda.make_graphed_callables) fails: when g1.replay() is called, nothing happens and the output placeholder tensor remains unchanged.
import torch

data_cuda0 = [torch.ones((1), dtype=torch.float, device="cuda:0") for _ in range(3)]
output0 = torch.zeros((1), dtype=torch.float, device="cuda:0")
data_cuda1 = [torch.ones((1), dtype=torch.float, device="cuda:1") for _ in range(3)]
output1 = torch.zeros((1), dtype=torch.float, device="cuda:1")

torch.cuda.set_device("cuda:0")
g0 = torch.cuda.CUDAGraph()

# warmup on a side stream
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        output0.zero_()
        for t in data_cuda0:
            output0.add_(t)
torch.cuda.current_stream().wait_stream(s)

# capture
with torch.cuda.graph(g0):
    for t in data_cuda0:
        output0.add_(t)

output0.zero_()
g0.replay()
print(output0)  # tensor([3.], device='cuda:0')

torch.cuda.set_device("cuda:1")
g1 = torch.cuda.CUDAGraph()

# warmup on a side stream
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        output1.zero_()
        for t in data_cuda1:
            output1.add_(t)
torch.cuda.current_stream().wait_stream(s)

# capture
with torch.cuda.graph(g1):
    for t in data_cuda1:
        output1.add_(t)

output1.zero_()
g1.replay()
print(output1)  # tensor([0.], device='cuda:1') -- replay did nothing
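For completeness, here is a rough sketch of the variant that works for me: the same second-GPU warmup and capture, but with a stream explicitly created on cuda:1 and passed to torch.cuda.graph via its stream argument (the variable name s1 is just illustrative).

# Workaround sketch: create a stream on cuda:1 ourselves and hand it to
# torch.cuda.graph instead of relying on its default internal side stream.
torch.cuda.set_device("cuda:1")
s1 = torch.cuda.Stream(device="cuda:1")  # stream explicitly created on cuda:1
g1 = torch.cuda.CUDAGraph()

# warmup on the same explicit stream
s1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s1):
    for i in range(3):
        output1.zero_()
        for t in data_cuda1:
            output1.add_(t)
torch.cuda.current_stream().wait_stream(s1)

# capture on the explicit stream
with torch.cuda.graph(g1, stream=s1):
    for t in data_cuda1:
        output1.add_(t)

output1.zero_()
g1.replay()
print(output1)  # now prints tensor([3.], device='cuda:1') as expected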