Memory consumption for the model gets doubled after wrapping it with DDP

Hi.

I'm trying to use DistributedDataParallel on a single-GPU node, for practice.

I checked my model's size right after initializing it with
sum([p.numel()*4 for p in model.parameters()])/1024/1024, and it reported 1140.xxx (MiB).

Then I wrapped the model with DistributedDataParallel:
dmodel = DistributedDataParallel(module=model, ...)

Then I checked the allocated memory with
torch.cuda.memory_allocated()/1024/1024, and it reported 2281.xx, which is almost double my model size.
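The size formula above hardcodes 4 bytes per element, which is only right for float32 parameters. A slightly more general sketch (the helper name param_mib is hypothetical, not part of PyTorch) asks each tensor for its own element size instead:

```python
import torch


def param_mib(module: torch.nn.Module) -> float:
    """Return the memory held by a module's parameters, in MiB."""
    # element_size() is the number of bytes per element for the tensor's dtype
    # (4 for float32, 2 for float16/bfloat16), so no hardcoded 4 is needed.
    total_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    return total_bytes / 1024 / 1024


model = torch.nn.Linear(1024, 1024)  # 1024*1024 weights + 1024 biases, float32
print(param_mib(model))  # 4.00390625
```

Keep in mind that torch.cuda.memory_allocated() counts every tensor currently allocated on the device, not only parameters, so the two numbers only line up when nothing else is resident.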

I thought this was because I assigned the wrapped model to a new variable, dmodel, instead of reusing model.
So I tried wrapping the model with DistributedDataParallel again, this time without the dmodel variable:
model = DistributedDataParallel(module=model, ...)

But torch.cuda.memory_allocated()/1024/1024 still reported 2281.xx.

Is it expected behavior that the model's memory consumption doubles when using DDP?


Could you show how you are using DDP on a single device, please?
Based on the memory usage you are seeing, I would guess you might be creating two processes on the same device, which would then also initialize two models.

Thank you for your endless help, @ptrblck.

I'm following the way DDP is used in this repository.

Also, just in case it matters:

  • my model contains a BiLSTM
  • I'm also using a quantizer (omitted from the simplified code below).

Simplified version of my __main__.py:

The order in which objects are initialized matches my actual source code.

import argparse

import torch
from torch import distributed
from torch.nn.parallel import DistributedDataParallel

def train(args):

    device = torch.device(f'cuda:{args.rank}')
    torch.cuda.set_device(device=device)
    distributed.init_process_group(
        backend='nccl',
        init_method=f'tcp://{args.master_url}',
        world_size=args.world_size,
        rank=args.rank,
    )

    # Instantiate my torch.utils.data.Dataset object
    train_dataset = MyDataset()

    # Instantiate my model
    model = MyModule()
    model.to(device)

    # Augmentation modules, which have no parameters.
    augmentations = [
        Augmentation1(),
        Augmentation2(),
        Augmentation3(),
    ]
    augmentations = torch.nn.Sequential(*augmentations)
    augmentations.to(device)

    # Instantiate the rest.
    optimizer = ...
    criterion = ...

    # I've checked the memory usage here, and it says 1140.xx MiB.

    # Wrap the model with DDP.
    model = DistributedDataParallel(
        module=model,
        device_ids=[torch.cuda.current_device()],
        output_device=torch.cuda.current_device(),
    )

    # I've checked the memory usage here again, and it says 2281.xx MiB.

    # Instantiate my Trainer class, whose simplified version is below.
    trainer = Trainer(
        model=model,
        dataset=train_dataset,
        augmentations=augmentations,
        criterion=criterion,
        optimizer=optimizer,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=device,
        world_size=args.world_size,
    )

    for epoch in range(args.epochs):
        for metrics in trainer.train(epoch):  # train for one epoch
            print(metrics)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--config_file", type=str)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--batch_size", type=int)  # batch_size was otherwise undefined
    parser.add_argument("--num_workers", type=int)
    parser.add_argument("--rank", type=int, default=-1)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--master_url", type=str)

    args = parser.parse_args()
    train(args)

Simplified version of my Trainer class:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


class Trainer:
    def __init__(self, model, dataset, augmentations, criterion, optimizer,
                 batch_size, num_workers, device, world_size):
        self.model = model
        self.dataset = dataset
        self.augmentations = augmentations
        self.criterion = criterion
        self.optimizer = optimizer
        self.num_workers = num_workers // world_size
        self.device = device
        self.world_size = world_size
        self.batch_size = batch_size // world_size

        # DistributedSampler reads the rank and world size from the
        # default process group initialized in train().
        self.sampler = DistributedSampler(
            dataset=dataset,
            shuffle=dataset.is_trainset(),
        )

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=self.batch_size if dataset.is_trainset() else 1,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=dataset.is_trainset() is False,
            sampler=self.sampler,
        )

    def train(self, epoch):
        self.model.train()

        # Makes the sampler shuffle differently in each epoch.
        self.sampler.set_epoch(epoch)

        for x, y in self.dataloader:
            self.optimizer.zero_grad()
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            x = self.augmentations(x)

            y_hat = self.model(x)

            # Loss functions take the prediction first and the target second.
            cost = self.criterion(y_hat, y)
            cost.backward()
            self.optimizer.step()

            del x, y, y_hat

            yield cost.item()

run.py, which launches the processes.

The script is nearly identical to the one in the referenced repository:


import socket
import subprocess as sp
import sys

import torch


def get_free_port():
    # Bind to port 0 so the OS picks an unused port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def main():

    args = sys.argv[1:]  # arguments for __main__.py

    gpus = torch.cuda.device_count()  # supposed to be 1 in my case.
    free_port = get_free_port()
    master_url = f'127.0.0.1:{free_port}'

    args += ["--world_size", str(gpus), "--master_url", master_url]

    # Launch one process per GPU, including rank 0.
    tasks = []
    for gpu in range(gpus):
        tasks.append(sp.Popen(["python3", "-m", "my_model"] + args + ["--rank", str(gpu)]))
        tasks[-1].rank = gpu

    failed = False
    while tasks:
        # Iterate over a copy, since finished tasks are removed from the list.
        for task in list(tasks):
            try:
                exitcode = task.wait(0.1)
            except sp.TimeoutExpired:
                continue
            else:
                tasks.remove(task)
                if exitcode:
                    print(f"Task {task.rank} died with exit code "
                          f"{exitcode}",
                          file=sys.stderr)
                    failed = True
        if failed:
            break
    if failed:
        for task in tasks:
            task.terminate()
        sys.exit(1)


if __name__ == "__main__":
    main()

Thanks again.

I've debugged my script line by line and found that the allocated memory doubles when torch.distributed.Reducer is instantiated in the constructor of DistributedDataParallel.

I think the reducer is a necessary component of DDP, because it sums up the gradients from all devices.
But I don't know how the reducer works, so I still can't understand why the memory doubles.

  1. Is it expected that the reducer allocates as much additional memory as the local model itself?
  2. Does the reducer allocate this additional memory only on the rank 0 device?
    In other words, would the additional memory not be allocated on rank 1 or rank 2?
    I can't check this because I have only one GPU.

That’s great debugging!
I've checked the behavior with @mcarilli, and he confirms that the Reducer creates gradient buckets for all parameters, so the memory usage after wrapping the model in DDP will be 2 x model_parameter_size. Note that a model's parameters often take much less memory than its activations, so this increase may or may not be significant.
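A rough, single-process sketch of where the extra copy comes from. This is a simplification, not DDP's actual Reducer code: here a single flat buffer (bucket) holds all gradients, whereas real DDP splits parameters into several buckets (sized via the bucket_cap_mb argument), fills them as gradients become ready during backward, and all-reduces each bucket. The helper names allocate_bucket and sync_gradients are hypothetical. Gradients are copied into the buffer, averaged across ranks (the all-reduce is only indicated in a comment, since there is one process), and copied back.

```python
import torch


def allocate_bucket(model: torch.nn.Module) -> torch.Tensor:
    # One flat buffer with room for every trainable parameter's gradient.
    # This allocation is what adds roughly one model size of memory.
    params = [p for p in model.parameters() if p.requires_grad]
    total = sum(p.numel() for p in params)
    return torch.zeros(total, dtype=params[0].dtype, device=params[0].device)


def sync_gradients(model: torch.nn.Module, bucket: torch.Tensor, world_size: int = 1) -> None:
    params = [p for p in model.parameters() if p.requires_grad]
    offset = 0
    for p in params:
        n = p.numel()
        bucket[offset:offset + n].copy_(p.grad.reshape(-1))  # gradient -> bucket
        offset += n
    # Real DDP would all-reduce the bucket across ranks here, e.g. with
    # torch.distributed.all_reduce(bucket); with one process it is skipped.
    bucket.div_(world_size)  # averaging: summed gradients divided by world size
    offset = 0
    for p in params:
        n = p.numel()
        p.grad.copy_(bucket[offset:offset + n].view_as(p))  # bucket -> gradient
        offset += n


model = torch.nn.Linear(256, 128)
model(torch.randn(4, 256)).sum().backward()
bucket = allocate_bucket(model)
sync_gradients(model, bucket)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(bucket.numel() * bucket.element_size() == param_bytes)  # True
```

Because p.grad and bucket are separate tensors in this sketch, gradients are stored twice once backward() has run; avoiding that duplication is the purpose of DDP's gradient_as_bucket_view option.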


By default, DDP maintains a buffer of the same size as the model's parameters, so it is expected that memory usage is double the model size. If you set gradient_as_bucket_view=True, peak memory allocation will be reduced by roughly one copy of the model size.
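A minimal sketch of enabling the option, assuming a single-process gloo group on CPU purely so it runs without a GPU; the thread's setup uses backend="nccl" with one process per GPU instead. The keyword is gradient_as_bucket_view (singular "gradient"), and the get_free_port helper here is a hypothetical stand-in for the one used by run.py.

```python
import socket

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def get_free_port() -> int:
    # Ask the OS for an unused TCP port by binding to port 0.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


# Single-process group on CPU, for illustration only.
dist.init_process_group(
    backend="gloo",
    init_method=f"tcp://127.0.0.1:{get_free_port()}",
    world_size=1,
    rank=0,
)

model = torch.nn.Linear(32, 8)
# With gradient_as_bucket_view=True, each p.grad is a view into DDP's
# communication buckets rather than a separately allocated tensor.
ddp_model = DistributedDataParallel(model, gradient_as_bucket_view=True)

loss = ddp_model(torch.randn(16, 32)).sum()
loss.backward()
print([tuple(p.grad.shape) for p in ddp_model.parameters()])  # [(8, 32), (8,)]

dist.destroy_process_group()
```

For the GPU case in this thread, one could pass device_ids=[torch.cuda.current_device()] as in the original script and compare torch.cuda.memory_allocated() after backward() with and without the flag.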
