By the way, here is the demo code:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import Dataset
from torch.profiler import profile, record_function, ProfilerActivity
# Define the model
class streamModel(torch.nn.Module):
    """Two-layer MLP whose two input branches run on separate CUDA streams.

    Each branch applies ``fc1`` + ReLU, all-gathers the activations from every
    rank, concatenates them along dim 0 and applies ``fc2``. The two branch
    outputs are then concatenated along dim 1.

    Requires a CUDA device and an initialized ``torch.distributed`` process
    group, because ``forward`` calls ``dist.all_gather``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)  # Input dimension is 10
        self.fc2 = torch.nn.Linear(20, 2)
        self.stream_odd = torch.cuda.Stream()  # Custom stream for x_odd
        self.stream_even = torch.cuda.Stream()  # Custom stream for x_even

    def _run_branch(self, x, stream, producer_stream):
        """Run fc1 -> ReLU -> all_gather -> cat -> fc2 for ``x`` on ``stream``."""
        # The side stream must not start before the work that produced ``x``
        # (e.g. the non-blocking host-to-device copy) has been queued and done.
        stream.wait_stream(producer_stream)
        with torch.cuda.stream(stream):
            # ``x`` was allocated on another stream; tell the caching allocator
            # it is in use here so its memory is not recycled too early.
            x.record_stream(stream)
            hidden = torch.relu(self.fc1(x))
            # Use all_gather to collect the activations from all processes.
            # NOTE(review): dist.all_gather presumably does not record autograd
            # history for its outputs, so gradients would not reach fc1 --
            # confirm whether a differentiable all_gather is wanted here.
            gathered = [torch.zeros_like(hidden) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, hidden)
            # Concatenate the per-rank results and use them after all_gather.
            out = self.fc2(torch.cat(gathered, dim=0))
        return out

    def forward(self, x_odd, x_even):
        """Process both inputs on their own streams and merge the results.

        Args:
            x_odd: Input tensor for the first branch (last dim 10 for ``fc1``).
            x_even: Input tensor for the second branch (last dim 10 for ``fc1``).

        Returns:
            The two branch outputs concatenated along dim 1.
        """
        # Ensure inputs are on GPU
        x_odd = x_odd.cuda(non_blocking=True)
        x_even = x_even.cuda(non_blocking=True)

        main_stream = torch.cuda.current_stream()
        out_odd = self._run_branch(x_odd, self.stream_odd, main_stream)
        out_even = self._run_branch(x_even, self.stream_even, main_stream)

        # Join the custom streams back into the consuming stream. This orders
        # the GPU work without blocking the host; a device-wide synchronize
        # call does not accept a Stream object.
        for side_stream, out in ((self.stream_odd, out_odd), (self.stream_even, out_even)):
            main_stream.wait_stream(side_stream)
            out.record_stream(main_stream)

        # Merge outputs
        return torch.cat([out_odd, out_even], dim=1)
# Define a simple dataset
class RandomDataset(Dataset):
    """In-memory dataset of ``length`` random vectors with ``size`` features."""

    def __init__(self, size, length):
        # One row per sample, drawn once from a standard normal distribution.
        samples = torch.randn(length, size)
        self.data = samples
        self.len = length

    def __len__(self):
        """Return the number of samples given at construction time."""
        return self.len

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        sample = self.data[index]
        return sample
# Initialize the distributed environment
def setup(rank, world_size):
    """Initialize the NCCL process group for this worker.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the group.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Bind this process to its GPU before creating the NCCL group, so that
    # later collectives do not have to guess which device the rank uses.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
# Clean up the distributed environment
def cleanup():
    """Destroy the default process group if one is currently initialized.

    Safe to call more than once, or after a failed initialization: when no
    process group exists the call is a no-op instead of raising.
    """
    if dist.is_initialized():
        dist.destroy_process_group()
# Training function
def train(rank, world_size):
    """Run the profiled DDP training demo on a single rank.

    Builds the model, data pipeline and optimizer, then trains for two epochs
    while a ``torch.profiler`` session writes traces to ``./log``.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes taking part in training.
    """
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    try:
        # Create model and wrap it with DDP
        model = streamModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=True)  # Enable find_unused_parameters

        # Create dataset and data loader
        dataset = RandomDataset(size=10, length=100)  # Input dimension is 10
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = DataLoader(dataset, batch_size=20, sampler=sampler)

        # Define loss function and optimizer
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

        # The context manager starts the profiler on entry and stops it on
        # exit, so the session is closed even if a training step raises.
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # Record CPU and GPU activities
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),  # Configure Profiler schedule
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),  # Save logs to ./log directory
            record_shapes=True,  # Record tensor shapes
            profile_memory=True,  # Record memory usage
            with_stack=True,  # Record call stack
        ) as prof:
            # Training loop
            for epoch in range(2):  # 2 epochs
                sampler.set_epoch(epoch)  # Set epoch to shuffle data
                for step, batch in enumerate(dataloader):
                    optimizer.zero_grad()

                    # Both branches receive the full batch in this demo.
                    x_odd = batch
                    x_even = batch

                    # Forward pass
                    with record_function("forward"):
                        output = ddp_model(x_odd, x_even)

                    # Compute loss
                    with record_function("loss_computation"):
                        target = torch.randn_like(output)  # Random target
                        loss = criterion(output, target)

                    # Backward pass
                    with record_function("backward"):
                        loss.backward()

                    # Optimizer step
                    with record_function("optimizer_step"):
                        optimizer.step()

                    prof.step()  # Profiler records current step

                print(f"Rank {rank}, Epoch {epoch}, Loss: {loss.item()}")
    finally:
        # Always tear the process group down, including on failure.
        cleanup()
# Launch multi-process training
def run_demo(world_size):
    """Spawn ``world_size`` worker processes running ``train`` and wait for them.

    Args:
        world_size: Number of processes to launch; also passed to each worker.
    """
    # mp.spawn prepends the process index (the rank) to these arguments.
    worker_args = (world_size,)
    mp.spawn(train, args=worker_args, nprocs=world_size, join=True)
# Run the test
if __name__ == "__main__":
    # Script entry point: launch one worker process per GPU.
    world_size = 2  # Use 2 GPUs
    run_demo(world_size=world_size)