Hi everyone,
I am trying to optimize my PyTorch training pipeline by overlapping host-to-device memory transfers with GPU computation. Below is a minimal example illustrating what I am attempting to do:
import torch
from torch.profiler import profile, ProfilerActivity

# Handler that saves the trace under a name containing the profiler step number
def chrome_trace_handler(p):
    # Called by the profiler when a trace is ready (with no schedule, that is
    # when the `with profile(...)` block exits).
    # Saves the trace to async_trace_<step_num>.json.
    p.export_chrome_trace(f"async_trace_{p.step_num}.json")
    print(f"Trace saved for step {p.step_num}")

device = torch.device("cuda")

# Four ~400 MB float32 matrices in pinned (page-locked) host memory
big_mat_a = torch.rand(9999, 9999, pin_memory=True)
big_mat_b = torch.rand(9999, 9999, pin_memory=True)
big_mat_c = torch.rand(9999, 9999, pin_memory=True)
big_mat_d = torch.rand(9999, 9999, pin_memory=True)

copy_stream = torch.cuda.Stream()
another_copy_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

# Warm up the CUDA context and cuBLAS before profiling
_ = torch.randn(1, 1, device="cuda") @ torch.randn(1, 1, device="cuda")
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], on_trace_ready=chrome_trace_handler) as prof:
    # H2D copies of a and b
    with torch.cuda.stream(copy_stream):
        a_gpu = big_mat_a.to(device=device, non_blocking=True)
        b_gpu = big_mat_b.to(device=device, non_blocking=True)
    # H2D copies of c and d, intended to overlap with the matmul below
    with torch.cuda.stream(another_copy_stream):
        # torch.cuda._sleep(int(20 * 1e9))  # busy-wait ~20e9 GPU cycles (~10 s) in another_copy_stream only
        c_gpu = big_mat_c.to(device=device, non_blocking=True)
        d_gpu = big_mat_d.to(device=device, non_blocking=True)
    # Matmul on a and b
    with torch.cuda.stream(compute_stream):
        result_a = torch.matmul(a_gpu, b_gpu)
    print(result_a.cpu())
print("done")
My goal is to have the H2D copies for c and d take place in parallel with the matmul computation using a and b. I explicitly placed these operations into different CUDA streams, and all tensors are in pinned memory. However, based on the profiler trace shown below, the operations still appear to run sequentially rather than overlapping.
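As a rough cross-check on the profiler trace, a wall-clock comparison can indicate whether the copies and the matmul overlap at all: if issuing both on separate streams takes about as long as the slower of the two alone, they overlapped; if it takes about their sum, they ran back to back. This is a minimal sketch, assuming a single CUDA device and pinned host tensors; the helpers `wall_ms` and `overlap_report` and the default size `n=4096` are hypothetical choices for illustration.

```python
import time

import torch


def wall_ms(fn, repeats=3):
    """Best-of-N wall-clock time of fn() in milliseconds.

    torch.cuda.synchronize() waits for all streams on the device, so the
    measurement also covers work queued on side streams.
    """
    best = float("inf")
    for _ in range(repeats):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn()
        torch.cuda.synchronize()
        best = min(best, (time.perf_counter() - t0) * 1e3)
    return best


def overlap_report(n=4096, device="cuda"):
    # Hypothetical size; the code above uses 9999 x 9999 matrices.
    a_gpu = torch.rand(n, n, device=device)
    b_gpu = torch.rand(n, n, device=device)
    c_cpu = torch.rand(n, n, pin_memory=True)
    d_cpu = torch.rand(n, n, pin_memory=True)
    copy_stream = torch.cuda.Stream(device=device)
    compute_stream = torch.cuda.Stream(device=device)

    def copies():
        # H2D copies only; results are discarded, which is fine for timing.
        with torch.cuda.stream(copy_stream):
            c_cpu.to(device, non_blocking=True)
            d_cpu.to(device, non_blocking=True)

    def matmul():
        # Compute only; a_gpu/b_gpu are already resident on the device.
        with torch.cuda.stream(compute_stream):
            torch.matmul(a_gpu, b_gpu)

    def both():
        # Queue the matmul first so it is pending when the copies start.
        matmul()
        copies()

    # a_gpu/b_gpu were filled on the current (default) stream, so the side
    # streams must not start before that work finishes.
    current = torch.cuda.current_stream(device)
    compute_stream.wait_stream(current)
    copy_stream.wait_stream(current)

    # Warm-up: cuBLAS setup for compute_stream and allocator growth.
    matmul()
    copies()

    return {
        "copy_ms": wall_ms(copies),
        "matmul_ms": wall_ms(matmul),
        "both_ms": wall_ms(both),
    }
```

If the GPU can run the copy engine and the SMs concurrently, `both_ms` would be expected to land near `max(copy_ms, matmul_ms)`, and near `copy_ms + matmul_ms` otherwise. Wall-clock numbers include launch overhead, so they are a coarse signal to read alongside the trace, not a replacement for it.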
I would appreciate guidance on:
- Whether this usage of multiple CUDA streams is correct for overlapping transfer and compute.
- Whether additional synchronization or event dependencies are required.
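For reference, the cross-stream dependencies can be written out explicitly. This is a minimal sketch, assuming the same pinned host tensors and a single CUDA device; `overlapped_step` is a hypothetical helper, not a PyTorch API. It uses `Stream.wait_stream` so the matmul cannot start before the a/b copies finish (in the code above, `compute_stream` has no such dependency on `copy_stream`, so the matmul may read `a_gpu`/`b_gpu` before they are filled), and `Tensor.record_stream` so the caching allocator does not hand out memory that another stream is still using.

```python
import torch


def overlapped_step(a_cpu, b_cpu, c_cpu, d_cpu,
                    copy_stream, prefetch_stream, compute_stream,
                    device="cuda"):
    """Hypothetical helper: multiply a@b while prefetching c and d.

    Assumes all four host tensors are in pinned memory.
    """
    # 1) H2D copies of the operands needed now.
    with torch.cuda.stream(copy_stream):
        a_gpu = a_cpu.to(device, non_blocking=True)
        b_gpu = b_cpu.to(device, non_blocking=True)

    # 2) The matmul depends on those copies, so compute_stream waits for them.
    compute_stream.wait_stream(copy_stream)
    with torch.cuda.stream(compute_stream):
        # a_gpu/b_gpu were allocated on copy_stream; record their use here.
        a_gpu.record_stream(compute_stream)
        b_gpu.record_stream(compute_stream)
        result = torch.matmul(a_gpu, b_gpu)

    # 3) Prefetch the next operands. These copies do not depend on the
    #    matmul, so the copy engine can run them while the matmul runs.
    with torch.cuda.stream(prefetch_stream):
        c_gpu = c_cpu.to(device, non_blocking=True)
        d_gpu = d_cpu.to(device, non_blocking=True)

    # 4) Work issued later on the current stream (e.g. result.cpu())
    #    must wait for both side streams.
    current = torch.cuda.current_stream(device)
    current.wait_stream(compute_stream)
    current.wait_stream(prefetch_stream)
    for t in (result, c_gpu, d_gpu):
        t.record_stream(current)
    return result, c_gpu, d_gpu
```

In a training loop the three streams would be created once and reused, with `c_gpu`/`d_gpu` becoming the next step's operands. Note that the a/b and c/d copies both go host-to-device, so they typically share one copy engine and run one after the other; only the c/d copies are positioned to overlap the matmul.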
Thanks in advance for any insight.

