Memory consumption of the model doubles after wrapping with DDP

Thank you for your endless help, @ptrblck .

I’m following the way DDP is used in this repository.

And just in case, let me mention that:

  • my model contains a BiLSTM (a placeholder sketch follows this list)
  • I’m also using a quantizer (omitted from the abridged code below).
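
For reference, a minimal sketch of the model’s overall shape (names and sizes are placeholders, not my real architecture, and the quantizer is left out):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, input_size=80, hidden_size=256):  # placeholder sizes
        super().__init__()
        self.bilstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,  # the BiLSTM mentioned above
        )
        self.head = torch.nn.Linear(2 * hidden_size, input_size)

    def forward(self, x):
        # x: (batch, time, input_size)
        out, _ = self.bilstm(x)
        return self.head(out)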

An abridged version of my project’s __main__.py.

The order of instance initialization is identical to my actual source code.

import argparse

import torch
from torch import distributed
from torch.nn.parallel import DistributedDataParallel

def train(args):

    device = torch.device(f'cuda:{args.rank}')
    torch.cuda.set_device(device=device)
    distributed.init_process_group(
        backend='nccl',
        init_method=f'tcp://{args.master_url}',
        world_size=args.world_size,
        rank=args.rank,
    )

    # Instantiate my torch.utils.data.Dataset object
    train_dataset = MyDataset()

    # Instantiate my model
    model = MyModule()
    model.to(device)

    # Augmentation modules, which have no parameters.
    augmentations = [
        Augmentation1(),
        Augmentation2(),
        Augmentation3(),
    ]
    augmentations = torch.nn.Sequential(*augmentations)
    augmentations.to(device)

    # And instantiate the rest (details omitted).
    optimizer = ...
    criterion = ...
    batch_size = ...

    # I've checked the memory usage here (see the measurement sketch below
    # this block), and it says 1140.xx MiB.

    # Wrap the model with DDP.
    model = DistributedDataParallel(
        module=model,
        device_ids=[torch.cuda.current_device()],
        output_device=torch.cuda.current_device(),
    )

    # I've checked the memory usage here again, and it says 2281.xx MiB.

    # Instantiate my Trainer class, which is abridged further below.
    trainer = Trainer(
        model=model,
        dataset=train_dataset,
        augmentations=augmentations,
        criterion=criterion,
        optimizer=optimizer,
        batch_size=batch_size,
        num_workers=args.num_workers,
        device=device,
        world_size=args.world_size,
    )

    for epoch in range(args.epochs):
        for metrics in trainer.train(epoch):  # train for 1 epoch
            print(metrics)
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--config_file", type=str)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--num_workers", type=int)
    parser.add_argument("--rank", type=int, default=-1)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--master_url", type=str)

    args = parser.parse_args()
    train(args)
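
For reference, a minimal sketch of the kind of memory check I mean at the two marked points (simplified; torch.cuda.memory_allocated reports the memory currently occupied by tensors on the device):

def report_memory(device, tag):
    # Currently allocated tensor memory on this device, in MiB.
    allocated = torch.cuda.memory_allocated(device) / 2 ** 20
    print(f'{tag}: {allocated:.2f} MiB')

report_memory(device, 'before DDP')  # -> 1140.xx MiB
report_memory(device, 'after DDP')   # -> 2281.xx MiB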

An abridged version of my Trainer class

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

class Trainer:
    def __init__(self, model, dataset, augmentations, criterion, optimizer,
                 batch_size, num_workers, device, world_size):
        self.model = model
        self.dataset = dataset
        self.augmentations = augmentations
        self.criterion = criterion
        self.optimizer = optimizer
        self.num_workers = num_workers // world_size
        self.device = device
        self.world_size = world_size
        self.batch_size = batch_size // world_size

        self.sampler = DistributedSampler(
            dataset=dataset,
            shuffle=dataset.is_trainset(),
        )

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size if dataset.is_trainset() else 1,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=not dataset.is_trainset(),
            sampler=self.sampler,
        )
    
    def train(self, epoch):
        self.model.train()

        self.sampler.set_epoch(epoch)

        for x, y in self.dataloader:
            self.optimizer.zero_grad()
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            x = self.augmentations(x)

            y_hat = self.model(x)

            cost = self.criterion(input=y_hat, target=y)
            cost.backward()
            self.optimizer.step()

            del x, y, y_hat

            yield cost.item()
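
To put the jump in context, the model’s parameter footprint can be computed with a small helper like this (a sketch, not part of my actual code):

def parameter_mib(model):
    # Total size of all parameters, in MiB.
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return total_bytes / 2 ** 20

If the increase after wrapping is close to this value, it would be consistent with DDP allocating communication buffers on the order of the gradient (i.e. parameter) size.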

run.py for spawning the worker processes.

The script is nearly identical to the one in the referenced repository.


import subprocess as sp
import sys

import torch


def main():

    args = sys.argv[1:]  # arguments forwarded to __main__.py

    gpus = torch.cuda.device_count()  # supposed to be 1 in my case
    free_port = get_free_port()  # helper defined elsewhere
    master_url = f'127.0.0.1:{free_port}'

    args += ["--world_size", str(gpus), "--master_url", master_url]

    tasks = []
    for gpu in range(gpus):
        # One subprocess per GPU; rank 0 is spawned as a subprocess as well.
        tasks.append(sp.Popen(["python3", "-m", "my_model"] + args + ["--rank", str(gpu)]))
        tasks[-1].rank = gpu
    
    failed = False
    while tasks:
        for task in list(tasks):  # iterate over a copy; finished tasks are removed
            try:
                exitcode = task.wait(0.1)
            except sp.TimeoutExpired:
                continue
            else:
                tasks.remove(task)
                if exitcode:
                    print(f"Task {task.rank} died with exit code "
                          f"{exitcode}",
                          file=sys.stderr)
                    failed = True
        if failed:
            break
    if failed:
        for task in tasks:
            task.terminate()
        sys.exit(1)


if __name__ == "__main__":
    main()
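
For completeness, I launch everything through run.py, roughly like this (file names and values are just examples):

python3 run.py --config_file config.yaml --epochs 100 --num_workers 4

run.py then appends --world_size, --master_url, and --rank before starting my_model's __main__.py in each subprocess.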

Thanks again.