With torch.cuda.graph(g): capture fails on multiple GPUs

Hi, I encounter a silent failure when I try to capture CUDA graphs on multiple GPUs using the with torch.cuda.graph(g): context. It turns out that with torch.cuda.graph(g0, stream=s): works.

The problem is that only one of the two CUDA graphs successfully captures the wrapped functions; the other graph's replay() seems to do nothing. By chance, I found that this problem can be fixed by explicitly creating a new stream and using with torch.cuda.graph(g0, stream=s): (a minimal sketch is below). Could someone from the community explain why, and what is the expected way to use CUDA graphs on multiple GPUs?
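For reference, here is a minimal self-contained sketch of the workaround, assuming a machine with two GPUs. The only change relative to the failing code further down is that the capture stream is created explicitly and passed to torch.cuda.graph (the side-stream warmup is omitted here for brevity):

import torch

torch.cuda.set_device("cuda:1")
data = [torch.ones((1), dtype=torch.float, device="cuda:1") for _ in range(3)]
output = torch.zeros((1), dtype=torch.float, device="cuda:1")

s = torch.cuda.Stream(device="cuda:1")  # explicitly created stream on the target device
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=s):  # pass the new stream to the capture context
    for t in data:
        output.add_(t)
output.zero_()
g.replay()
print(output)  # with the explicit stream this prints tensor([3.], device='cuda:1')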

more details

What I intended to do is basically to use CUDA graphs to accelerate the in-place add of two tensor lists on two different GPUs separately. The following code (mostly adapted from torch.cuda.make_graphed_callables) fails: when g1.replay() is called, nothing happens and the output tensor output1 remains unchanged.

import torch

data_cuda0 = [torch.ones((1), dtype=torch.float, device="cuda:0") for _ in range(3)]
output0 = torch.zeros((1), dtype=torch.float, device="cuda:0")

data_cuda1 = [torch.ones((1), dtype=torch.float, device="cuda:1") for _ in range(3)]
output1 = torch.zeros((1), dtype=torch.float, device="cuda:1")

torch.cuda.set_device("cuda:0")
g0 = torch.cuda.CUDAGraph()

# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        output0.zero_()
        for t in data_cuda0:
            output0.add_(t)
torch.cuda.current_stream().wait_stream(s)

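# capture the in-place adds on cuda:0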
with torch.cuda.graph(g0):
    for t in data_cuda0:
        output0.add_(t)
output0.zero_()
g0.replay()
print(output0) # tensor([3.], device='cuda:0')

torch.cuda.set_device("cuda:1")
g1 = torch.cuda.CUDAGraph()

# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        output1.zero_()
        for t in data_cuda1:
            output1.add_(t)
torch.cuda.current_stream().wait_stream(s)

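# capture the in-place adds on cuda:1 (this graph's replay does nothing)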
with torch.cuda.graph(g1):
    for t in data_cuda1:
        output1.add_(t)
output1.zero_()
g1.replay()
print(output1) # tensor([0.], device='cuda:1'), replay had no effect

Could you post a minimal, executable code snippet to reproduce this error, please?

Thanks for the reply. I've uploaded an executable demo code snippet.