I am training my model with MAML (Model-Agnostic Meta-Learning) using torch DDP with the NCCL backend. I use 8 GPUs (A100s) on a single machine, and I have a problem where the script freezes without giving any error message when I try to evaluate the model only on the base process (rank == 0). Training goes well until evaluation, but as soon as the processes reach the line

if iteration % eval_freq == eval_freq - 1:

the script freezes.
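My rough understanding of this class of hang (happy to be corrected): under NCCL, every collective has to be entered by all ranks in the process group, so if one rank posts a collective that the others never join, all ranks block silently. A standalone toy example with that shape (hypothetical, not my training code; needs 2 GPUs):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    t = torch.ones(1, device=rank)
    if rank == 0:
        dist.broadcast(t, src=0)  # a collective that only rank 0 enters
    dist.barrier()                # mismatched with rank 0's broadcast => silent freeze

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

My real script (below) should only hit the barrier on all ranks, which is why I'm confused.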
Some observations:
- The script works well when I use a smaller model (tiny-mbart, for debugging), but it hangs when I use mbart-large-50, which has more than 600M parameters (see the quick check after this list).
- On another server with two A6000 GPUs, the script works well even with the large model.
- The script does not work when I use the large model with 8 A100 GPUs on a single machine.
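For what it's worth, the parameter count of the large model checks out (standard Hugging Face / PyTorch; facebook/mbart-large-50 is the checkpoint I load):

from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
print(sum(p.numel() for p in model.parameters()))  # ~611M parameters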
I'd really appreciate any suggestions on how to solve this problem. Thank you!
Here is the relevant training loop:

def meta_train(self,
               train_batch,
               train_sample_set,
               eval_set,
               num_steps,
               eval_freq,
               dataloader_batch_size,
               current_run_info,
               maml,
               optimizer,
               scheduler,
               scaler,
               rank):
    if rank == 0:
        run_info_writer = open(LOG_DIR / "{}.json".format(current_run_info["id"]), 'w')
        # exit handler => log the run info when the program exits
        atexit.register(log_handle, current_run_info, run_info_writer)
        save_best_model = SaveBestModel()
    for iteration in tqdm(range(self.initial_step, num_steps + self.initial_step)):
        # print(f'iteration: {iteration}, rank: {rank}, param_0: {list(self.maml.parameters())[0]}')
        loss = self.outer_step(train_batch, dataloader_batch_size, scaler)
        if rank == 0:
            wandb.log({"step": iteration, "loss": loss})
        # average the accumulated gradients over the meta-batch (task_size)
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / len(train_batch))
        # update the model params and the lr schedule
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()
        print("step: {}, loss: {:.4f}".format(iteration, loss))
        if iteration % eval_freq == eval_freq - 1:  # eval on train sample set
            torch.cuda.empty_cache()
            print(rank)
            if rank == 0:
                # evaluation on a set-aside train subset to monitor learning
                train_sample_loss, train_sample_accuracy = self.outer_step_eval(
                    train_sample_set, dataloader_batch_size, iteration, split="train")
                print("step: {}, train_sample_loss: {:.4f}, train_sample_accuracy: {:.4f}".format(
                    iteration, train_sample_loss, train_sample_accuracy))
                wandb.log({"step": iteration,
                           "train_sample_loss": train_sample_loss,
                           "train_sample_accuracy": train_sample_accuracy})
            torch.distributed.barrier()
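In case it's relevant, my current guess (please correct me if this is wrong): maml is the DDP-wrapped model, and outer_step_eval runs forward passes through it on rank 0 only. As far as I understand, DDP with broadcast_buffers=True (the default) broadcasts buffers from rank 0 at the start of every forward, so a rank-0-only forward may post NCCL collectives that the other ranks, already waiting at the barrier, never match. The sketch below is what I plan to try: unwrap the DDP container for evaluation so the eval forwards issue no collectives (maml.module is the standard DDP attribute; the model= kwarg for outer_step_eval is hypothetical and would require changing that method):

if iteration % eval_freq == eval_freq - 1:
    torch.cuda.empty_cache()
    if rank == 0:
        # run eval through the plain module so its forward passes don't
        # post any NCCL collectives (buffer broadcasts etc.)
        base_model = maml.module  # standard DDP attribute
        train_sample_loss, train_sample_accuracy = self.outer_step_eval(
            train_sample_set, dataloader_batch_size, iteration,
            split="train", model=base_model)  # "model=" kwarg is hypothetical
    torch.distributed.barrier()

I also plan to rerun with NCCL_DEBUG=INFO and TORCH_DISTRIBUTED_DEBUG=DETAIL to see where each rank actually blocks.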