Simple Distributed Training Example

Apologies, but I am having trouble following the official PyTorch tutorials. I have a single machine with two GPUs, and I would like to use both for training.

The example below is a modification of the following:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# A single process with world_size=1: this process is the only rank (rank 0).
torch.distributed.init_process_group("nccl", rank=0, world_size=1)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


rank = dist.get_rank()
print(f"Start running basic DDP example on rank {rank}.")

# create model and move it to GPU with id rank
device_id = rank % torch.cuda.device_count()
model = ToyModel().to(device_id)
ddp_model = DDP(model, device_ids=[device_id])

loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

for _ in range(10000):
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))  # DDP moves the input to device_id
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()

dist.destroy_process_group()
```

However, only one GPU is being used. When I increase world_size to 2, the program hangs in init_process_group. I have seen people recommend init_method='env://', but that raises an error unless I set world_size and rank through os.environ.

What is the correct way to use init_process_group?

Hi Joseph,

You can check this DDP example to see how to implement it.
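As a rough illustration of that pattern (a sketch, not the linked example itself), assume one machine and one process per GPU. Each process is started by torch.multiprocessing.spawn with its own rank, and every process calls init_process_group with the same world_size. This is also why a single process with world_size=2 hangs: init_process_group waits until all world_size ranks have joined (or the timeout expires). The function names `train` and `main`, the port 12355, and the gloo/CPU fallback (added only so the sketch also runs without two GPUs) are assumptions.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size, steps=100):
    # mp.spawn passes the process index as the first argument; use it as the rank.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"  # assumed free port
    # Assumption: fall back to gloo on CPU if there are fewer GPUs than processes.
    use_cuda = torch.cuda.device_count() >= world_size
    backend = "nccl" if use_cuda else "gloo"
    # Every process passes the same world_size; init blocks until all ranks join.
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    device = torch.device(f"cuda:{rank}") if use_cuda else torch.device("cpu")
    model = ToyModel().to(device)
    # device_ids is only used for a CUDA module placed on a single GPU.
    ddp_model = DDP(model, device_ids=[rank] if use_cuda else None)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range(steps):
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10, device=device))
        labels = torch.randn(20, 5, device=device)
        # backward() all-reduces gradients across ranks before step().
        loss_fn(outputs, labels).backward()
        optimizer.step()

    dist.destroy_process_group()


def main(world_size=2):
    # Start world_size processes (one per GPU); process i runs train(i, world_size).
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
```

Call `main()` from a script's `if __name__ == "__main__":` guard: the spawn start method re-imports the main module in each child process, and the guard keeps the children from spawning again.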

I plan to finish the DDP example and add it to the PyTorch/examples repo soon. Please let me know if you think anything needs to be improved.

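Yes, if you are using init_method='env://', it is correct to set world_size and rank through os.environ (as WORLD_SIZE and RANK); they can also be passed directly as the world_size and rank arguments of init_process_group.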
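A minimal sketch of env:// initialization, assuming a single process (world_size=1) and the CPU gloo backend so it can run without GPUs. With init_method='env://', init_process_group reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment, so rank and world_size are not passed as arguments. The helper name `init_from_env` and port 12356 are hypothetical choices for illustration.

```python
import os

import torch.distributed as dist


def init_from_env(rank, world_size, backend="gloo"):
    # Hypothetical helper: env:// reads all four variables set below.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"  # assumed free port
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    # No rank/world_size arguments: they come from the environment.
    dist.init_process_group(backend, init_method="env://")
    return dist.get_rank(), dist.get_world_size()
```

For example, `init_from_env(0, 1)` returns `(0, 1)`; call `dist.destroy_process_group()` when done. In a two-GPU run, each process would set its own RANK (0 or 1) with WORLD_SIZE=2; launchers such as torchrun set these variables for you.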

See this tutorial for init_process_group(): Writing Distributed Applications with PyTorch (PyTorch Tutorials 1.12.0+cu102 documentation).