DDP hangs during evaluation without any error message

I am training my model with MAML (model-agnostic meta-learning) using torch DDP with the NCCL backend. I use 8 A100 GPUs on a single machine, and the script freezes without any error message when I try to evaluate the model only on the base process (rank == 0). Training runs fine until the first evaluation, but as soon as one of the processes enters the if iteration % eval_freq == eval_freq - 1 branch, the script freezes.

Some observations:

  1. The script works with a smaller model (tiny-mbart, for debugging) but hangs with mbart-large-50, which has more than 600M parameters.
  2. On another server with two A6000 GPUs, the script works even with the large model.
  3. The script hangs with the large model on 8 A100 GPUs on a single machine.

I’d really appreciate any suggestions for solving this problem. Thank you!

    def meta_train(self,

        if rank == 0:
            run_info_writer = open(LOG_DIR / "{}.json".format(current_run_info["id"]), 'w')
            atexit.register(log_handle, current_run_info, run_info_writer) # exit handler => log the run info when the program exits
        save_best_model = SaveBestModel()

        for iteration in tqdm(range(self.initial_step, num_steps + self.initial_step)):
            # print(f'iteration: {iteration}, rank: {rank}, param_0: {list(self.maml.parameters())[0]}')
            loss = self.outer_step(train_batch, dataloader_batch_size, scaler)
            if rank == 0:
                wandb.log({"step": iteration, "loss": loss})

            # adjust gradient by division (task_size)
            for p in maml.parameters():
                if p.grad is not None:  # parameters that received no gradient have grad None
                    p.grad.mul_(1.0 / len(train_batch))

            # update model params and update lr
            print("step: {}, loss: {:.4f}".format(iteration, loss))

            if iteration % eval_freq == eval_freq - 1: # eval on train sample set
                if rank == 0:
                    # evaluation on a set-aside training set to monitor learning
                    train_sample_loss, train_sample_accuracy = self.outer_step_eval(train_sample_set, dataloader_batch_size, iteration, split="train")
                    print("step: {}, train_sample_loss: {:.4f}, train_sample_accuracy: {:.4f}".format(iteration, train_sample_loss, train_sample_accuracy))
                    wandb.log({"step": iteration, "train_sample_loss": train_sample_loss, "train_sample_accuracy": train_sample_accuracy})
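Collective calls such as an all-reduce are matched by the order in which each rank issues them. Running evaluation only on rank 0 should therefore just delay the other ranks, which block inside their next collective until rank 0 catches up; one common cause of a hang that never resolves is that the ranks issued different numbers of collectives. Below is a minimal simulation of that idea using only the standard library, assuming each rank is a thread and each collective is a wait on a shared threading.Barrier with a timeout. It is not torch.distributed code, and run_ranks, calls_per_rank and rank0_delay are hypothetical names.

```python
import threading
import time


def run_ranks(calls_per_rank, rank0_delay=0.0, timeout=1.0):
    """Simulate ranks that each issue calls_per_rank[r] 'collective' calls.

    Each collective is a wait on one shared threading.Barrier, so a call only
    returns once every rank has reached it (like an all-reduce). rank0_delay
    models rank-0-only evaluation before rank 0 rejoins the collectives.
    Returns {rank: "ok" | "hung"}; "hung" means the barrier timed out.
    """
    world_size = len(calls_per_rank)
    barrier = threading.Barrier(world_size, timeout=timeout)
    status = {}

    def worker(rank):
        if rank == 0:
            time.sleep(rank0_delay)  # stand-in for evaluation on rank 0 only
        try:
            for _ in range(calls_per_rank[rank]):
                barrier.wait()  # blocks until all ranks arrive
            status[rank] = "ok"
        except threading.BrokenBarrierError:
            # A real job would block here; the timeout only lets the simulation end.
            status[rank] = "hung"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dict(sorted(status.items()))


# Same number of collectives on every rank: the delay on rank 0 only slows things down.
print(run_ranks([3, 3, 3, 3], rank0_delay=0.2))
# Rank 0 issues one extra collective (e.g. inside evaluation): it can never complete.
print(run_ranks([4, 3, 3, 3]))
```

In the real training step, the gradient all-reduce plays the role of the barrier; the sketch does not model NCCL's own timeout or error handling.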

When a module is wrapped with DDP, DDP installs hooks on it, which may delay some allreduce calls during the backward pass (and therefore before the optimizer step). I am not sure that is the root cause of your problem, but it might affect how the model behaves if you run evaluation inside the training loop.

Thank you very much for your answer, @wanchaol. I based my code on this code here, so I didn’t wrap my model with the DDP wrapper, and I launch it with the torchrun command. The weird thing is that even when I replace the evaluation part with time.sleep(), it still hangs. I added a print after each line to trace the problem: as soon as any running process reaches the barrier, the remaining processes freeze (specifically right before optimizer.step()) without printing any error message.

    # -- model training --

    if iteration % eval_freq == 0:
        if rank == 0:
            ...  # evaluation here (replaced by time.sleep() while debugging)
        torch.distributed.barrier()  # all ranks wait here until rank 0 finishes
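The barrier version only works if every rank, not just rank 0, calls torch.distributed.barrier(), so that all ranks issue the same sequence of collectives. Below is a small standard-library simulation of that pattern, assuming threads as ranks, a threading.Barrier as the stand-in for torch.distributed.barrier(), and time.sleep() as the stand-in for evaluation (mirroring the time.sleep() experiment described above). train_loop and its parameters are hypothetical.

```python
import threading
import time


def train_loop(world_size=4, num_steps=6, eval_freq=3, eval_seconds=0.1):
    """Simulate a loop where rank 0 alone evaluates every eval_freq steps.

    Returns (leave_times, eval_done): for each rank, the times at which it left
    each eval barrier, and the times at which rank 0 finished each evaluation.
    """
    barrier = threading.Barrier(world_size, timeout=5.0)
    leave_times = {r: [] for r in range(world_size)}
    eval_done = []

    def worker(rank):
        for iteration in range(num_steps):
            # -- model training would go here (its collectives match across ranks) --
            if iteration % eval_freq == 0:
                if rank == 0:
                    time.sleep(eval_seconds)  # stand-in for rank-0-only evaluation
                    eval_done.append(time.monotonic())
                barrier.wait()  # called by every rank, not only rank 0
                leave_times[rank].append(time.monotonic())

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return leave_times, eval_done


leave_times, eval_done = train_loop()
print(len(eval_done), {r: len(v) for r, v in leave_times.items()})
```

Moving the barrier inside the if rank == 0: branch would recreate the mismatch: rank 0 would wait on a barrier that the other ranks never call.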

I think I found where the problem comes from. I was using torch.cuda.amp.GradScaler() for mixed precision, and whenever the script hung, it was always right before scaler.step(optimizer). I replaced that with a plain optimizer.step() (so no mixed precision now), and the script now runs to the end. I have no idea why it caused the problem, but at least my script no longer hangs.
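One possible explanation, which is only a hypothesis and not confirmed in this thread: scaler.step(optimizer) skips optimizer.step() when it finds inf/NaN in the gradients, and that check looks at each process's own gradients. Without the DDP wrapper, gradient synchronization has to happen in the user's code; if that all-reduce runs inside optimizer.step() (for example in a distributed optimizer wrapper, which is an assumption here), one rank can skip the step while the others do not, so the others wait forever in an all-reduce that rank never issues. Two ways around it would be to make the skip decision collectively (in torch, something like torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MAX)) or to all-reduce the gradients before scaler.step(), as the DDP wrapper effectively does during backward. The sketch below simulates only the first idea with threads; scaled_step and agree_on_skip are hypothetical names.

```python
import threading


def scaled_step(found_inf_per_rank, agree_on_skip, timeout=0.5):
    """Simulate one GradScaler-style optimizer step on every rank.

    found_inf_per_rank[r] says whether rank r saw inf/NaN in its local grads.
    The optimizer step is assumed (hypothetically) to contain a collective,
    modelled as a Barrier wait, e.g. a gradient all-reduce inside step().
    Returns {rank: "stepped" | "skipped" | "hung"}.
    """
    world_size = len(found_inf_per_rank)
    barrier = threading.Barrier(world_size, timeout=timeout)
    flags = [False] * world_size
    result = {}

    def worker(rank):
        try:
            found_inf = found_inf_per_rank[rank]
            if agree_on_skip:
                # Stand-in for a MAX all-reduce of the flag: if any rank saw
                # inf/NaN, every rank skips the step.
                flags[rank] = found_inf
                barrier.wait()
                found_inf = any(flags)
            if found_inf:
                result[rank] = "skipped"  # like GradScaler skipping optimizer.step()
            else:
                barrier.wait()  # collective inside the optimizer step
                result[rank] = "stepped"
        except threading.BrokenBarrierError:
            result[rank] = "hung"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dict(sorted(result.items()))


print(scaled_step([True, False, False], agree_on_skip=False))  # ranks 1 and 2 hang
print(scaled_step([True, False, False], agree_on_skip=True))   # all ranks skip together
```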