Hi,
I’m using DistributedDataParallel to train a simple classification model. I have some experience with distributed training, but I can’t seem to wrap my head around one specific detail.
Let me refer you to an example provided by PyTorch: examples/main.py at master · pytorch/examples · GitHub
Here, you will see that the accuracy is calculated by an accuracy() function, and the average accuracy is updated using the AverageMeter in the following lines.
From my understanding, this calculates the accuracy only for the samples that each GPU receives, not the accuracy across all GPUs. So this function returns top1.avg, and at L248 the model is checkpointed if the accuracy from GPU rank 0 is larger than the best accuracy so far.
Am I going crazy, or is this intended behavior? Are we assuming that all GPUs receive the same samples, or that the accuracy on GPU 0 is somehow representative of the overall accuracy?
To show that my interpretation is correct, I wrote a short sandbox script that mimics the code linked above. The AverageMeter and accuracy() functions were copied from the linked code base. The script assumes a 2-class classification scenario with batch_size=5, and I ran it on 4 GPUs.
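First, the two helpers as I copied them, reproduced so the snippet below is self-contained (this is how I remember them from main.py; the exact revision in the repo may differ slightly):

import torch

class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res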
And the sandbox itself:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

def run(gpu):
    acc_meter = AverageMeter()
    model = nn.Linear(10, 2).cuda(gpu)
    model = DistributedDataParallel(model, device_ids=[gpu])
    a = torch.randn(5, 10).cuda(gpu)            # batch of 5 random inputs
    gt = torch.randint(0, 2, (5, 1)).cuda(gpu)  # random binary labels
    outputs = model(a)
    acc = accuracy(outputs, gt, topk=(1,))      # per-GPU top-1 accuracy (%)
    acc_meter.update(acc[0], a.size(0))
    print("Avg: ", acc_meter.avg)
    print("Sum: ", acc_meter.sum)
    print("Count: ", acc_meter.count)
    return acc_meter.avg
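(For completeness, one way to launch this on 4 GPUs on a single node; the address and port here are arbitrary:)

import torch.distributed as dist
import torch.multiprocessing as mp

def worker(gpu, world_size):
    # One process per GPU; NCCL backend since we all-reduce GPU tensors
    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:23456",
                            rank=gpu, world_size=world_size)
    run(gpu)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)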
There are two issues:
- As suspected, returning acc_meter.avg will only return the accuracy for the current GPU. This means that saving a checkpoint or logging from rank=0 will only checkpoint or log the accuracy from rank=0.
- The accuracy calculation is wrong. The accuracy() function already divides by the batch_size, so returning acc_meter.avg divides by the batch_size again. The return value should be acc_meter.sum.
Ultimately, I would like to write code that uses DistributedDataParallel but still computes the accuracy correctly. For now, I have resorted to the following method:
- Compute num_correct on each GPU.
- All-reduce num_correct as well as num_samples, as such: dist.all_reduce(num_correct), dist.all_reduce(num_samples). For this step, num_samples must first be cast to a GPU tensor.
- Cast back to CPU, then update the average meter (see the sketch after this list).
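In code, the workaround looks roughly like this (global_accuracy is just an illustrative name):

import torch
import torch.distributed as dist

def global_accuracy(num_correct, num_samples, gpu):
    # Wrap the Python ints in CUDA tensors, since NCCL all_reduce
    # only operates on GPU tensors
    correct = torch.tensor(float(num_correct), device=f"cuda:{gpu}")
    samples = torch.tensor(float(num_samples), device=f"cuda:{gpu}")
    dist.all_reduce(correct)  # default reduction op is SUM
    dist.all_reduce(samples)
    # Back to CPU scalars, then compute the global accuracy in percent
    return 100.0 * correct.item() / samples.item()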
To me, this does not seem like an elegant solution. Perhaps the answer is an AverageMeter that can handle distributed updates, something like the sketch below? In search of a more elegant solution, I’ve looked at multiple code bases, but they all seem to do it incorrectly, as shown above. Am I completely missing something big here? If anyone has suggestions/solutions, I would love to hear from you.
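A rough sketch of what I have in mind (DistributedAverageMeter is hypothetical, not taken from any of the code bases above):

import torch
import torch.distributed as dist

class DistributedAverageMeter:
    # Accumulate locally, then synchronize across ranks on demand
    def __init__(self, device):
        self.device = device
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def all_reduce(self):
        # Sum the running statistics across all ranks in one call
        t = torch.tensor([self.sum, self.count],
                         dtype=torch.float64, device=self.device)
        dist.all_reduce(t)
        self.sum, self.count = t[0].item(), int(t[1].item())

    @property
    def avg(self):
        return self.sum / max(self.count, 1)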