When I run below code with world size=1, everything runs fine but when I run with world size>1, model = DDP(model, device_ids=[rank]) isn’t executed.
Here is the code I’m running:
import os
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from torchvision.datasets import MNIST
from torchvision import transforms
bsize = 5000
def seed_everything(seed):
import random
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '12345'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def train(rank, world_size):
setup(rank, world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = torch.nn.Linear(10, 10).to(device)
model = DDP(model, device_ids=[rank])
dataset = MNIST(root='./data', download=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]), train=False)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, drop_last=True)
dataloader = DataLoader(dataset, batch_size=bsize, sampler=sampler)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(5):
sampler.set_epoch(epoch)
epoch_loss = 0.0
for inputs, targets in tqdm(dataloader, desc=f"Rank {rank}, Epoch {epoch}"):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
inputs = inputs.view(inputs.size(0), -1)
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(f"Rank {rank}, Epoch {epoch}, Loss: {epoch_loss:.4f}")
cleanup()
if __name__ == "__main__":
seed_everything(1)
world_size = torch.cuda.device_count()
seed = 42 # Set a fixed seed for reproducibility
mp.spawn(train, args=(world_size, seed), nprocs=world_size, join=True)