import argparse
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument('--dist_url', default="env://", type=str)
parser.add_argument('--rank', default=0, type=int)
parser.add_argument('--gpu_to_work_on', default=0, type=int)
params = parser.parse_args()
def example():
    # Set up the process group and pick the GPU for this rank.
    init_dist(params)

    # DDP wraps the model and all-reduces gradients across ranks.
    model = nn.Linear(10, 10).cuda(params.gpu_to_work_on)
    ddp_model = DDP(model, device_ids=[params.gpu_to_work_on])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # One dummy forward/backward/step pass.
    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).cuda(params.gpu_to_work_on))
    labels = torch.randn(20, 10).cuda(params.gpu_to_work_on)
    loss_fn(outputs, labels).backward()
    optimizer.step()
def init_dist(params):
    # RANK and WORLD_SIZE are set in the environment by torch.distributed.launch.
    params.rank = int(os.environ["RANK"])
    params.world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(
        backend="nccl",
        init_method=params.dist_url,
        world_size=params.world_size,
        rank=params.rank,
    )
    # Assign each rank to one local GPU and make it the default CUDA device.
    params.gpu_to_work_on = params.rank % torch.cuda.device_count()
    print('rank:', params.rank)
    print('gpu_to_work_on:', params.gpu_to_work_on)
    print('n_gpus:', torch.cuda.device_count())
    torch.cuda.set_device(params.gpu_to_work_on)
if __name__ == '__main__':
    example()
I launch it with (adding --nproc_per_node so one process is spawned per GPU, and --use_env so the launcher does not pass a --local_rank argument that my parser would reject):

python -m torch.distributed.launch --nproc_per_node=<num_gpus> --use_env main.py
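For a multi-node job I believe the same script would be launched once per node, roughly like this (the node count, node rank, GPU count, and master address are placeholders):

python -m torch.distributed.launch --nnodes=2 --node_rank=0 \
    --nproc_per_node=<gpus_per_node> \
    --master_addr=<node0_ip> --master_port=29500 \
    --use_env main.py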
Am I using DistributedDataParallel correctly?
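For context, in a real run I would also shard the data so each rank trains on a distinct subset. A minimal sketch, assuming it sits inside example() after the optimizer is created (the TensorDataset, batch size, and epoch count are placeholders of my own, not from the example above):

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 10))
# DistributedSampler hands each rank a disjoint shard of the indices.
sampler = DistributedSampler(dataset, num_replicas=params.world_size, rank=params.rank)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffle the shards each epoch
    for inputs, labels in loader:
        inputs = inputs.cuda(params.gpu_to_work_on)
        labels = labels.cuda(params.gpu_to_work_on)
        optimizer.zero_grad()
        loss_fn(ddp_model(inputs), labels).backward()
        optimizer.step()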