Here is an exhaustive list of the available PyTorch communication primitives. As sample code, I took the liberty of adapting your excerpt (the idea is identical for your test script):
for epoch in range(1, num_epochs + 1):
    # Initialize the running metrics for this epoch
    running_loss = 0.0
    running_corrects = 0.0
    model.train()
    # Reshuffle the DistributedSampler so every epoch sees a different ordering
    train_loader.sampler.set_epoch(epoch)
    # One epoch of training
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(rank)
        labels = labels.to(rank)
        # Get model predictions
        outputs = model(images)
        # Compute loss
        J = loss(outputs, labels)
        # Empty accumulated gradients
        optimizer.zero_grad()
        # Perform backprop
        J.backward()
        # Update parameters
        optimizer.step()
        '''
        global step
        optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
        step += 1
        '''
        # Accumulate the model's performance statistics on each rank
        running_loss += J.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        running_corrects += torch.sum(predicted == labels.data).item()
    # To globally reduce the local metrics across ranks, they must be tensors
    running_loss = torch.tensor([running_loss], device=rank)
    running_corrects = torch.tensor([running_corrects], device=rank)
    if torch.cuda.is_available():
        torch.distributed.reduce(running_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.reduce(running_corrects, dst=0, op=torch.distributed.ReduceOp.SUM)
    # Log the aggregated metrics only on the 0th GPU. Make sure "train_dataset" is a
    # Dataset and not a DataLoader, so len() returns the size of the full dataset and
    # not of the local shard.
    if rank == 0:
        train_loss = running_loss.item() / len(train_dataset)
        train_acc = (running_corrects.item() / len(train_dataset)) * 100
        print(f"GPU: {rank}, epoch = {epoch}; train loss = {train_loss:.4f} & train accuracy = {train_acc:.2f}%")