Running Two Batches in Parallel Using CUDA Streams Does Not Overlap During Training

I have a model trained with DDP (DistributedDataParallel), and the communication overhead between GPUs is quite high. Because the computation depends on the transferred data, I cannot directly overlap communication and computation within a single batch. To work around this, I tried using CUDA streams to process two batches in parallel, so that the computation of one batch can overlap with the communication of the other. However, the profiler shows that the two forward passes do not overlap. Does anyone know why this might be happening?

    def forward(self, batch_dict_odd, batch_dict_even=None):
        """Run the odd and even batches through the pipeline and merge their losses.

        NOTE(review): this snippet appears to be missing lines. The indentation
        below jumps without any enclosing ``if``/``with`` statements, so it is
        not valid Python as shown. Presumably each branch was wrapped in a
        ``with torch.cuda.stream(...)`` block; TODO confirm against the real code.

        Args:
            batch_dict_odd: input for the "odd" branch. Presumably a dict of
                tensors, judging by the name; TODO confirm.
            batch_dict_even: input for the "even" branch; defaults to None.
                None is presumably handled by a lost guard, because the visible
                lines use it unconditionally.

        Returns:
            ``(loss, tb_dict, disp_dict)`` from ``self.merge_result`` in training,
            or whatever ``self.forward_export`` returns in export mode.
        """
        if self.export_mode:
            # NOTE(review): ``batch_dict`` is neither a parameter nor a local of
            # this method. Unless a global of that name exists, this raises
            # NameError. Presumably ``batch_dict_odd`` was intended.
            return self.forward_export(batch_dict)

                  # Even branch: run the batch through each sub-module in turn.
                  batch_dict_even = self.raw_encoder(batch_dict_even)
                  batch_dict_even = self.pillar_scatter(batch_dict_even)
                  batch_dict_even = self.lidar_bev_backbone(batch_dict_even)
                  batch_dict_even = self.map_encoder(batch_dict_even)
                  batch_dict_even = self.route_encoder(batch_dict_even)
                  batch_dict_even = self.e2e_ego_decoder(batch_dict_even)
                # NOTE(review): get_loss() receives no batch, so it presumably
                # reads state left on self by the calls above; TODO confirm.
                # If it calls .item()/.cpu() (common when filling tb_dict or
                # disp_dict), the host blocks here until the even branch has
                # finished on the GPU. The odd branch's kernels below would then
                # be launched too late to overlap with it.
                loss_even, tb_dict_even, disp_dict_even = self.get_loss()
              # Odd branch: same pipeline applied to the second batch.
              batch_dict_odd = self.raw_encoder(batch_dict_odd)
              batch_dict_odd = self.pillar_scatter(batch_dict_odd)
              batch_dict_odd = self.lidar_bev_backbone(batch_dict_odd)
              batch_dict_odd = self.map_encoder(batch_dict_odd)
              batch_dict_odd = self.route_encoder(batch_dict_odd)
              batch_dict_odd = self.e2e_ego_decoder(batch_dict_odd)

                loss_odd, tb_dict_odd, disp_dict_odd = self.get_loss()
            # Combine the two branches' losses and logging dicts into one result.
            loss, tb_dict, disp_dict = self.merge_result(loss_odd, tb_dict_odd, disp_dict_odd, loss_even, tb_dict_even, disp_dict_even)
            return loss, tb_dict, disp_dict```

This post shows how to overlap data transfers with computation. To overlap compute kernels, your GPU must have enough free resources. For example, if the kernel running on the first stream is already using all SMs, the kernels on the other stream have to wait until SMs become available.

Thank you so much for your help!

I created a demo, and it successfully overlaps computation with all_gather communication. However, when I apply the same approach to my actual model, it doesn’t work as expected. I’m currently trying to debug and figure out the root cause.

For reference, here is the demo code:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import Dataset
from torch.profiler import profile, record_function, ProfilerActivity

# Define the model
class streamModel(torch.nn.Module):
    """Toy model that processes two batches on two separate CUDA streams.

    Both branches compute ``fc2(cat(all_gather(relu(fc1(x)))))``. Issuing the
    "odd" and "even" branches on different streams lets the all_gather of one
    branch overlap with the compute of the other, provided the GPU has free
    resources for the concurrent kernels.

    NOTE: ``torch.cuda.Stream()`` without a device argument is created on the
    *current* CUDA device. In a multi-GPU process, call
    ``torch.cuda.set_device(rank)`` before constructing this model.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)  # Input dimension is 10
        self.fc2 = torch.nn.Linear(20, 2)
        self.stream_odd = torch.cuda.Stream()  # Custom stream for x_odd
        self.stream_even = torch.cuda.Stream()  # Custom stream for x_even

    def _branch(self, x):
        """Apply fc1 -> relu -> all_gather -> cat -> fc2 on the current stream.

        The per-rank results are concatenated along dim 0, so the output has
        ``world_size`` times as many rows as ``x`` and 2 columns.

        NOTE: ``dist.all_gather`` writes into fresh tensors and is not
        differentiable, so no gradient flows back into ``fc1`` through here.
        """
        hidden = torch.relu(self.fc1(x))
        gathered = [torch.zeros_like(hidden) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, hidden)
        return self.fc2(torch.cat(gathered, dim=0))

    def forward(self, x_odd, x_even):
        """Run the odd and even batches on their own streams and merge the results.

        Args:
            x_odd: batch for the odd branch; moved to the GPU if needed.
            x_even: batch for the even branch; moved to the GPU if needed.

        Returns:
            ``torch.cat([out_odd, out_even], dim=1)``, i.e. 2 + 2 = 4 columns.
        """
        x_odd = x_odd.cuda(non_blocking=True)
        x_even = x_even.cuda(non_blocking=True)

        main_stream = torch.cuda.current_stream()
        # The side streams must not start before the inputs are ready. The
        # inputs were copied/produced on the current stream.
        self.stream_odd.wait_stream(main_stream)
        self.stream_even.wait_stream(main_stream)

        with torch.cuda.stream(self.stream_odd):
            out_odd = self._branch(x_odd)
        with torch.cuda.stream(self.stream_even):
            out_even = self._branch(x_even)

        # Tell the caching allocator that the inputs are in use on the side
        # streams, so their memory is not handed out again too early.
        x_odd.record_stream(self.stream_odd)
        x_even.record_stream(self.stream_even)

        # Order the current stream after both side streams without blocking
        # the host, unlike a device-wide synchronize.
        main_stream.wait_stream(self.stream_odd)
        main_stream.wait_stream(self.stream_even)
        # The outputs were allocated on the side streams but are consumed on
        # the current stream from here on.
        out_odd.record_stream(main_stream)
        out_even.record_stream(main_stream)

        return torch.cat([out_odd, out_even], dim=1)

# Define a simple dataset
class RandomDataset(Dataset):
    """In-memory dataset of standard-normal samples, one row per index.

    Args:
        size: number of features per sample.
        length: number of samples.
    """

    def __init__(self, size, length):
        # Keep the count and the data separate. A chained assignment here would
        # store the tensor in ``self.len`` and break ``len(dataset)``.
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        """Return sample ``index`` as a 1-D tensor of ``size`` features."""
        return self.data[index]

    def __len__(self):
        return self.len

# Initialize the distributed environment
def setup(rank, world_size):
    """Initialise the NCCL process group for one worker process.

    Args:
        rank: index of this process; also selects the CUDA device it uses.
        world_size: total number of processes in the group.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Bind this process to its own GPU before any CUDA work. Otherwise every
    # rank's current device is cuda:0, and later torch.cuda.Stream() and
    # Tensor.cuda() calls (which default to the current device) land on GPU 0
    # instead of GPU `rank`.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

# Clean up the distributed environment
def cleanup():
    """Destroy the default process group if one is initialised.

    Safe to call when no group exists, e.g. after a failed setup or a
    second call.
    """
    if dist.is_initialized():
        dist.destroy_process_group()
# Training function
def train(rank, world_size):
    """Run the DDP training loop for one worker process (target of mp.spawn).

    Initialises the process group, wraps ``streamModel`` in DDP, and trains it
    for two epochs on random data under the PyTorch profiler. The process group
    is torn down again even if training raises.

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes.
    """
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)

    try:
        # Create model and wrap it with DDP.
        model = streamModel().to(rank)
        # The model's dist.all_gather is not differentiable, so fc1 never
        # receives a gradient. DDP must therefore tolerate unused parameters.
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)

        # Create dataset and data loader.
        dataset = RandomDataset(size=10, length=100)  # Input dimension is 10
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = DataLoader(dataset, batch_size=20, sampler=sampler)

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        # Using the profiler as a context manager guarantees it is stopped
        # even if an exception escapes the loop.
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            for epoch in range(2):
                sampler.set_epoch(epoch)  # Reshuffle differently each epoch.
                for batch in dataloader:
                    # Both branches get the same data in this demo.
                    x_odd = batch
                    x_even = batch

                    # Clear the previous step's gradients; otherwise they
                    # accumulate across iterations.
                    optimizer.zero_grad()

                    with record_function("forward"):
                        output = ddp_model(x_odd, x_even)
                    with record_function("loss_computation"):
                        target = torch.randn_like(output)  # Random target
                        loss = criterion(output, target)
                    with record_function("backward"):
                        loss.backward()
                    with record_function("optimizer_step"):
                        optimizer.step()

                    prof.step()  # Advance the profiler schedule.

                print(f"Rank {rank}, Epoch {epoch}, Loss: {loss.item()}")
    finally:
        cleanup()

# Launch multi-process training
def run_demo(world_size):
    """Start ``world_size`` worker processes, each running ``train(rank, world_size)``.

    ``mp.spawn`` calls the function as ``fn(i, *args)`` with ``i`` the process
    index. ``join=True`` blocks until every worker has exited.
    """
    worker_args = (world_size,)
    mp.spawn(
        train,
        args=worker_args,
        nprocs=world_size,
        join=True,
    )

# Run the test
if __name__ == "__main__":
    # NOTE(review): the visible part of this guard only sets world_size. It
    # should go on to call run_demo(world_size) to start training; TODO confirm
    # that line was not lost.
    world_size = 2  # Use 2 GPUs