Why does DistributedDataParallel consume more GPU memory compared to DataParallel during AMP training?

When I use DataParallel(), the maximum batch size I can set is 512 (with cudnn.benchmark disabled), but with DistributedDataParallel I can only set batch_size to 128.

Could cudnn cause such a problem?

This is the main structure of my code.

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

if __name__ == '__main__':
    ...
    # Multi GPU
    print(f'Running DDP on rank: {args.local_rank}')
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    main()

def main():
    ...
    train_sampler = DistributedSampler(train_dataset)
    same_seeds(args.seed)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers,
        # worker_init_fn=_init_fn,
        pin_memory=True,
        sampler=train_sampler
        )
    ...
    # Creates a GradScaler once at the beginning of training.
    scaler = GradScaler()

    # Convert BatchNorm layers to SyncBatchNorm and wrap the model in DDP
    # (one process per GPU).
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = DDP(net, device_ids=[args.local_rank], output_device=args.local_rank)
    cudnn.benchmark = True  # enable the cudnn auto-tuner
    ...
    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        train_loss, train_acc, batch_time = train(
            epoch, net, train_loader, criterion, optimizer, warmup_scheduler, scaler)
    ...

def train(epoch, net, train_loader, criterion, optimizer, scheduler, scaler):
    """Train for one epoch."""
    net.train()
    for batch_idx, (images, targets) in enumerate(train_loader):
        # Move the batch to the local GPU; only the forward pass and loss
        # computation need to run under autocast.
        images, targets = images.cuda(), targets.cuda()
        with autocast():
            logits = net(images)
            loss = criterion(logits, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        ...

AMP shouldn't use more memory, and I assume you are passing the global batch size to each process and thus to each GPU.
As explained here, you should set batch_size for each GPU to the local batch size, i.e. the global batch size divided by the number of GPUs.
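
A minimal sketch of what that could look like for the DataLoader in your main() (world_size and local_batch_size are names I'm introducing here for illustration; everything else reuses identifiers from your posted code):

import torch
import torch.distributed as dist

# Hypothetical adjustment: treat args.batch_size as the *global* batch size
# and derive the per-process (per-GPU) batch size from the world size.
world_size = dist.get_world_size()
assert args.batch_size % world_size == 0, "global batch size must be divisible by the number of GPUs"
local_batch_size = args.batch_size // world_size

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=local_batch_size,  # per-GPU batch size, not the global one
    shuffle=False,                # DistributedSampler already shuffles
    num_workers=args.num_workers,
    pin_memory=True,
    sampler=train_sampler,
)

For example, with 4 GPUs and a global batch size of 512, each process would then load 128 samples per step, which is roughly comparable in memory to a DataParallel run where a batch of 512 is scattered across the 4 GPUs.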