distributed.all_gather_object() produces multiple additional processes

Hi, I’m currently studying PyTorch DDP with 8 GPUs.

I’m trying to train and validate a model on multiple GPUs, and training seems to work fine.

In the validation phase, I gather the validation outputs onto rank 0 and print the validation accuracy and loss.
This works, but when dist.all_gather_object is called, I see that 7 additional processes are created.

I suspect there is some inefficiency in my code.

So my questions are:

  1. Why are so many processes created when I gather values from each GPU?
  2. How do I properly gather results from each GPU onto rank 0?
  3. Am I using DDP appropriately and efficiently?

Here is my code for each epoch, along with the GPU status:

    for epoch in range(args.epochs):
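        # DistributedSampler has to be told which epoch this is, which
        # guarantees a different shuffling order every epoch
        # (assumes train_loader was built with a DistributedSampler)
        train_loader.sampler.set_epoch(epoch)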
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, rank, args)
        val_acc, val_loss = valid(model, val_loader, criterion, rank, args)
        ## gather
        g_acc, g_loss = torch.randn(world_size), torch.randn(world_size)
        dist.all_gather_object(g_acc, val_acc)
        dist.all_gather_object(g_loss, val_loss)
        if rank == 0:
            val_acc, val_loss = g_acc.mean(), g_loss.mean()
            print(f"EPOCH {epoch} VALID: acc = {val_acc}, loss = {val_loss}")
            # save_checkpoint is a placeholder name: the original call was cut off in the post
            if val_acc > best_acc:
                best_acc = val_acc
                save_checkpoint({
                    "epoch": epoch+1,
                    "state_dict": model.module.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }, file_name=os.path.join(args.exp, "best_acc.pth"))
            if val_loss < best_loss:
                best_loss = val_loss
                save_checkpoint({
                    "epoch": epoch+1,
                    "state_dict": model.module.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }, file_name=os.path.join(args.exp, "best_loss.pth"))
            # always keep the most recent checkpoint
            save_checkpoint({
                "epoch": epoch+1,
                "state_dict": model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }, file_name=os.path.join(args.exp, "last.pth"))

Thanks for reading!
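For question 2, here is a minimal sketch of two ways to average a per-rank scalar metric, assuming `valid()` returns plain Python floats. `gather_mean` and `all_reduce_mean` are hypothetical helper names. `dist.all_gather_object` is documented to take a Python list with one slot per rank as its output argument, and every rank (not only rank 0) receives the full list. For a single number, `dist.all_reduce` on a one-element tensor avoids pickling altogether.

```python
import torch
import torch.distributed as dist


def gather_mean(value: float) -> float:
    """Average a Python scalar over all ranks using all_gather_object.

    The output argument must be a list with one slot per rank. Every rank
    receives the full list, so every rank computes the same mean.
    """
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, value)
    return sum(gathered) / len(gathered)


def all_reduce_mean(value: float, device: torch.device) -> float:
    """Average a scalar over all ranks with one all_reduce on a 1-element tensor.

    With the NCCL backend, `device` must be this rank's own GPU
    (e.g. torch.device("cuda", rank)); with gloo it can be the CPU.
    """
    t = torch.tensor([value], dtype=torch.float64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # in place: t now holds the sum
    return t.item() / dist.get_world_size()
```

Averaging per-rank means equals the global mean only when every rank evaluates the same number of samples; a DistributedSampler pads the dataset so ranks get equal shares, at the cost of a few duplicated samples.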

@Taejune How do you initialize the process group? all_gather_object itself won’t spawn new processes; processes are normally created by the user through a launcher. Could you please check why there are multiple new processes?
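To illustrate that the launcher, not the collective, creates the processes, here is a CPU-only sketch (gloo backend, so no GPUs are needed) in which `mp.spawn` starts one process per rank and `all_gather_object` collects each rank’s PID. `worker` and `launch` are hypothetical names, and the port number is an arbitrary assumption.

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, port: str) -> list:
    # mp.spawn runs this function once per rank, each in its own process.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = port
    # gloo runs on CPU, so this sketch needs no GPUs.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    pids = [None] * world_size
    # all_gather_object only exchanges data among the processes that already
    # exist; it does not start new ones.
    dist.all_gather_object(pids, os.getpid())
    if rank == 0:
        print(f"{world_size} ranks -> PIDs {pids}")
    dist.destroy_process_group()
    return pids


def launch(world_size: int = 2, port: str = "29500") -> None:
    # mp.spawn starts `world_size` processes and calls worker(rank, *args) in each.
    mp.spawn(worker, args=(world_size, port), nprocs=world_size, join=True)
```

Call `launch(world_size=2)` from inside an `if __name__ == "__main__":` guard: `mp.spawn` uses the spawn start method, so each child re-imports the main module, and the guard keeps the children from trying to launch again. Rank 0 should print exactly `world_size` distinct PIDs.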

Hi, thanks for your reply, and sorry for the late response.
I called dist.init_process_group("nccl", "env://", rank=rank, world_size=8).
(In the run I used a world size of 8, not the 1 shown in the code below.)
I used mp.spawn to create the multiple processes for DDP.

Could mp.spawn be the reason these additional processes are created?


import argparse
import os

import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size, args):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", init_method=args.dist_url, rank=rank, world_size=world_size)

def parse_args():
    parser = argparse.ArgumentParser(description="Imagenet Training")
    ## Config
    parser.add_argument("--exp", type=str, default="./exp/default")
    ## DDP
    parser.add_argument("--dist_url", type=str, default="env://")
    ## training
    parser.add_argument("--data_path", type=str, default="/home/data/imagenet")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--valid-iter", type=int, default=1000)
    ## data loader
    parser.add_argument("--pin-memory", action='store_true')
    parser.add_argument("--num-workers", type=int, default=2) # may become a bottleneck if set to 0
    parser.add_argument("--drop-last", action="store_true")
    parser.add_argument("--shuffle", action="store_true")
    return parser.parse_args()

if __name__ == "__main__":
    world_size = 1
    args = parse_args()

    # create the same directory the checkpoints are written to (os.path.join(args.exp, ...))
    os.makedirs(args.exp, exist_ok=True)
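    # "main" stands for the per-rank worker function; its name was not shown in the post
    mp.spawn(main,
             args=(world_size, args),
             nprocs=world_size,
             join=True)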