Understanding torch.nn.parallel.DistributedDataParallel from torchvision's reference example

I would like to ask some questions regarding the DDP code used in torchvision's reference example on classification. An example invocation of this script, on a machine with 8 GPUs, is:

python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --model resnext50_32x4d --epochs 100

My first question concerns the saving and loading of checkpoints.

This is how a checkpoint is saved in the script:

checkpoint = {
    'model': model_without_ddp.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'epoch': epoch,
    'args': args}
utils.save_on_master(
    checkpoint,
    os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
    checkpoint,
    os.path.join(args.output_dir, 'checkpoint.pth'))
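
For context, save_on_master is essentially a rank-0-guarded torch.save. A minimal sketch of that idea (based on the reference utils; exact details may differ):

import torch
import torch.distributed as dist

def is_main_process():
    # Non-distributed runs count as the "main" process.
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0

def save_on_master(*args, **kwargs):
    # Only rank 0 writes the checkpoint; every other rank skips the save.
    if is_main_process():
        torch.save(*args, **kwargs)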

But in the DDP tutorial, it seems necessary to call torch.distributed.barrier() around saving and loading:

# Use a barrier() to make sure that process 1 loads the model after process 0 saves it.
dist.barrier()
...
# Use a barrier() to make sure that all processes have finished reading the checkpoint
dist.barrier()

Why is dist.barrier() not necessary in the above reference example?

My second question is about the validation stage.

This is how it’s done in the script:

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)
    train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
    lr_scheduler.step()
    evaluate(model, criterion, data_loader_test, device=device)

Doesn’t this mean that the evaluate() function is called on all the processes (i.e. all the GPUs in this case)? Shouldn’t we rather do something like this:

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)
    train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
    lr_scheduler.step()
    if torch.distributed.get_rank() == 0: # master
        evaluate(model, criterion, data_loader_test, device=device)
        # save checkpoint here as well

But then, again, shouldn’t we wait, using dist.barrier(), for all the processes to finish the computations and for the master to gather the gradients, before evaluating the model?

Thank you very much in advance for your help!

Why is dist.barrier() not necessary in the above reference example?

IIUC, the torchvision example only saves the checkpoint to a file each epoch but does not read it back unless it is recovering from a crash? In that case the barrier is not a hard requirement (see the sketch after this list), because

  1. there is no risk of a non-master process reading a stale or half-written checkpoint, and
  2. non-master processes will simply block in the DDP backward (AllReduce) while waiting for the master to join.
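
For contrast, this is roughly the pattern the tutorial's demo_checkpoint() is guarding: rank 0 writes a checkpoint that every other rank immediately reads back. A minimal sketch of that pattern (CHECKPOINT_PATH and ddp_model are assumed to exist, and a process group is already initialized):

import os
import torch
import torch.distributed as dist

rank = dist.get_rank()

if rank == 0:
    # Only rank 0 writes the checkpoint.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

# Barrier 1: make sure the file is fully written before any rank reads it.
dist.barrier()

# Map tensors saved from rank 0's GPU onto this rank's GPU.
map_location = {'cuda:0': 'cuda:{}'.format(rank)}
ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))

# Barrier 2: make sure every rank has finished reading before rank 0
# removes or overwrites the file.
dist.barrier()
if rank == 0:
    os.remove(CHECKPOINT_PATH)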

Doesn’t this mean that the evaluate() function is called on all the processes (i.e. all the GPUs in this case)?

cc @fmassa for this implementation

But then, again, shouldn’t we wait, using dist.barrier(), for all the processes to finish the computations and for the master to gather the gradients, before evaluating the model?

The gradients are synchronized in the DDP backward pass using AllReduce operations, so there is no need to add another barrier for that. As soon as loss.backward() returns, the local gradients should already represent the global average. However, you might still want a barrier here for a different reason: if the evaluate() step takes too long, the non-master processes could time out on the AllReduce. If that happens, a barrier might help.
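
A minimal sketch of that idea, assuming (hypothetically) that evaluation ran on rank 0 only, which, as it turns out later in this thread, is not what the reference script actually does:

import torch.distributed as dist

# Hypothetical rank-0-only evaluation: hold the other ranks at a barrier so
# they do not sit in the next epoch's AllReduce (and risk timing out) while
# rank 0 is busy evaluating.
if dist.get_rank() == 0:
    evaluate(model, criterion, data_loader_test, device=device)
dist.barrier()  # all ranks resume training together after evaluation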

@mrshenli Thanks for your prompt reply!

IIUC, the torchvision example only saves the checkpoint to a file each epoch but does not read it back unless it is recovering from a crash?

Well, the script has a resume code as follows:

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    args.start_epoch = checkpoint['epoch'] + 1

so I guess it does what you’ve described, which is the usual scenario. But then, are you saying that the demo_checkpoint() example given in the tutorial handles a different scenario (other than resuming training)?

The gradients are synchronized in the DDP backward pass using AllReduce operations, so there is no need to add another barrier for that. As soon as loss.backward() returns, the local gradients should already represent the global average.

Thanks! This has cleared up a lot of things for me. I guess the gradients are averaged? In this case, it seems that the learning rate should be scaled up by the number of GPUs.
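
For what it's worth, a sketch of the linear scaling rule being described here (base_lr and the other optimizer hyper-parameters are assumed values; whether linear scaling is appropriate depends on the model and the per-GPU batch size):

import torch
import torch.distributed as dist

base_lr = 0.1                        # learning rate tuned for a single GPU (assumed)
world_size = dist.get_world_size()   # 8 for the launch command above
scaled_lr = base_lr * world_size     # linear scaling with the number of processes

optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr,
                            momentum=0.9, weight_decay=1e-4)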

Besides, the tutorial also notes that

if training starts from random parameters, you might want to make sure that all DDP processes use the same initial values. Otherwise, global gradient synchronization will not make sense.

but I don’t see this being taken into account anywhere in the reference script.

Well, the script has a resume code as follows:

Not 100% sure about how this example would be used; @fmassa would know more. Given the code, it looks like the resume mode is designed for starting from a pre-trained model or resuming from a crash. In either case, the checkpoint file is ready before the script is launched, so it should be fine.

you saying that the demo_checkpoint() example given in the tutorial handles a different scenario

It is not targeting any specific use case; we just wanted to make sure the example code can run as is. The main point it tries to convey is that applications need to make sure checkpoints are ready before loading them. We previously saw users run into weird errors caused by reading too soon.

if training starts from random parameters, you might want to make sure that all DDP processes use the same initial values. Otherwise, global gradient synchronization will not make sense.

DDP handles this by broadcasting the model weights from rank 0 to the other processes at construction time. However, if the application modifies the model weights after constructing DDP, and that results in inconsistent weights across processes, DDP won’t be able to recover, because the broadcast only happens once, in the constructor.
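
If you do end up modifying weights after wrapping the model in DDP, one way to re-synchronize them manually is to broadcast from rank 0. A sketch (broadcast_weights_from_rank0 is a hypothetical helper, not part of DDP):

import torch.distributed as dist

def broadcast_weights_from_rank0(ddp_model):
    # Re-broadcast parameters and buffers from rank 0, mirroring what DDP
    # itself does once at construction time.
    for param in ddp_model.module.parameters():
        dist.broadcast(param.data, src=0)
    for buf in ddp_model.module.buffers():
        dist.broadcast(buf, src=0)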

@mrshenli Great. Thank you so much for the explanations!
I hope @fmassa can join the discussion and clarify the points you cc’d him for earlier, especially the one related to evaluate(). I tested the code and it seems that this function is called only once across the processes.

The training loop looks like this:

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        train_sampler.set_epoch(epoch)
    train_one_epoch(...)
    lr_scheduler.step()
    evaluate(model, criterion, data_loader_test, device=device)

where the evaluation function looks like:

def evaluate(model, criterion, data_loader, device, print_freq=100):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            ...
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5))
    return metric_logger.acc1.global_avg

This function has a print() that displays the accuracy. In my experiment, that string is only displayed once, which suggests the function is called only once. Why? There is no is_main_process() check, so why isn’t this function called on all processes? I’m confused…

I’ve just realized that we shouldn’t wrap the evaluation phase inside if torch.distributed.get_rank() == 0:, because data_loader_test also splits the data across all the processes (it uses a DistributedSampler as well), so each process only evaluates its own shard of the test set.
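
That also explains why synchronize_between_processes() is there: each process only accumulates statistics for its own shard, and the logger all-reduces them so every process ends up with the global numbers. Conceptually it does something like this (a sketch of the idea, not the exact torchvision code):

import torch
import torch.distributed as dist

def synchronize_counter(count, total, device):
    # Sum the per-rank (count, total) pairs across all processes so that
    # every rank sees the accuracy statistics for the full test set.
    t = torch.tensor([count, total], dtype=torch.float64, device=device)
    dist.barrier()
    dist.all_reduce(t)
    count, total = t.tolist()
    return int(count), total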

And for the print() part, this code explains why the message is displayed only once.
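
For anyone else who lands here: the trick is to patch the built-in print so that only the master rank actually prints. Roughly (a sketch; the real helper may differ in its details):

import builtins

def setup_for_distributed(is_master):
    # Replace builtins.print so that only the master process prints,
    # unless a caller explicitly passes force=True.
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    builtins.print = print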

So now I understand what @fmassa did. Thanks.