DistributedDataParallel imbalanced GPU memory usage

Hi there,
I’m trying to train my network wrapped with DistributedDataParallel on a single machine with 4 GPUs. Training went smoothly until the 43rd epoch, when the process was interrupted by a CUDA out-of-memory error on GPU 2.

Traceback (most recent call last):
  File "train_ddp.py", line 247, in <module>
    trainer.training(epoch)
  File "train_ddp.py", line 171, in training
    iter_loss.backward()
  File "/scratch/workspace/zsding/anaconda3/lib/python3.6/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/scratch/workspace/zsding/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 752.00 MiB (GPU 2; 15.77 GiB total capacity; 10.24 GiB already allocated; 518.25 MiB free; 785.63 MiB cached)

Then I shrank the input size and resumed from my previous weights to try to debug the memory footprint. The chart below shows that there were three extra python threads running and occupying 1080 MiB on GPU 2, and I found that they shared the same PIDs as the threads on the other GPUs.


Of course, each GPU had only one thread during the first training epoch. There is no GPU-specific operation (like .to(2)) in my training script, but I did apply SyncBatchNorm to my model (could that be the reason?).
How can I figure out what those three threads are? Could you suggest some ways to solve this problem?

Thanks!

Hi, is it possible for you to provide a snippet of your code or a way to reproduce the issue you are seeing? Similar to the DataParallel imbalanced memory usage issue, it could be that the outputs of your forward pass are being gathered onto a single GPU (GPU 2 in your case), causing it to OOM.
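
If you can't share the code, a quick check (just a diagnostic sketch, with a made-up helper name, not something from your script) is to log the memory allocated on every GPU at a few points in the training loop, e.g. right after the forward pass and right after backward:

    import torch

    def log_gpu_memory(tag=''):
        # Print how much memory is currently allocated on each visible GPU.
        for d in range(torch.cuda.device_count()):
            mib = torch.cuda.memory_allocated(d) / 1024 ** 2
            print(f"{tag} cuda:{d} allocated: {mib:.1f} MiB")

If GPU 2 grows much faster than the others between those two points, that would point at tensors being gathered onto (or created on) that device.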

Hi, thanks for replying.
Here is the forward pass part. I don’t know if it is helpful.

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            image, target = image.to(self.device), target.to(self.device)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            outputs = self.model(image)

            # multi-scale training
            iter_loss = 0
            for logit in outputs:
                _, _, H, W = logit.shape
                labels_ = utils.resize_labels(target, size=(H, W))
                iter_loss += self.criterion(logit.cuda(), labels_.cuda())

            torch.cuda.empty_cache()
            iter_loss.backward()
            self.optimizer.step()
            train_loss += iter_loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

And there is no gather operation explicitly used in my code.

Sorry for the late reply here. Just to confirm, are you spawning a single process per device (GPU)?
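
For reference, the usual single-machine setup is one process per GPU, with each process pinned to its own device before the model is wrapped. A rough sketch (the names `local_rank`, `world_size`, and the SyncBatchNorm conversion are placeholders, not taken from your script):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def setup_process(local_rank, world_size):
        # One process per GPU; on a single machine rank == local_rank.
        dist.init_process_group(backend='nccl', init_method='env://',
                                rank=local_rank, world_size=world_size)
        # Pin this process to its own GPU so that bare .cuda() calls
        # (and torch.load without map_location) don't land on cuda:0.
        torch.cuda.set_device(local_rank)

    def wrap_model(model, local_rank):
        # Convert BatchNorm layers if you want SyncBatchNorm across processes.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(local_rank)
        return DDP(model, device_ids=[local_rank], output_device=local_rank)

If each of the four processes pins itself to its own device like this, the bare .cuda() calls in the training loop stay on that process's GPU, and nvidia-smi should show only one PID per GPU.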

A bit late here, but I had the exact same issue. The problem was that I was loading a state_dict (saved from device cuda:0) in four different processes, one per GPU, and the resulting effect was that all of them loaded the state_dict onto device cuda:0.

I solved it by loading the state_dict with:
torch.load(<state dict file path>, map_location=current_device)
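
For anyone hitting the same thing, a sketch of how that looks with one process per GPU (`checkpoint_path` and `local_rank` are placeholder names):

    # Remap tensors that were saved from cuda:0 onto this process's own GPU.
    state_dict = torch.load(checkpoint_path,
                            map_location={'cuda:0': 'cuda:%d' % local_rank})
    # Alternatively, pass the target device directly:
    # state_dict = torch.load(checkpoint_path,
    #                         map_location=torch.device('cuda:%d' % local_rank))
    model.load_state_dict(state_dict)

Without map_location, torch.load restores each tensor to the device it was saved from, so every process allocates a copy of the checkpoint on that one GPU.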