Memory leak on rank 0 in distributed training

Hi, I've run into a problem when training my model in a distributed way (with DistributedDataParallel).

In short: how do I fix the memory leak on rank 0?

When I run the code on a single GPU, it works well and uses 10.3 GB of GPU memory (my GPU is a 2080 Ti with 11 GB). But when I run it in a distributed way, I get an out-of-memory (OOM) error.
I built a class named Trainer that initializes the datasets and the models internally. The rough code is shown below. Each training step contains several backward passes.

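import gc
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils import data

# DataSet, ResNet and Classifier are my own classes (not shown).


class Trainer:
    def __init__(self):
        self.data_initial()
        self.model_initial()
        self.train()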

    def data_initial(self):
        source_data = DataSet(.....)
        source_sampler = data.distributed.DistributedSampler(source_data, seed=1234)
        source_dataloader = data.DataLoader(source_data, batch_size=..., num_workers=4,
            pin_memory=False, sampler=source_sampler)
        self.source_data = enumerate(source_dataloader)

        target_data = DataSet(.....)
        target_sampler = data.distributed.DistributedSampler(target_data, seed=1234)
        target_dataloader = data.DataLoader(target_data, batch_size=..., num_workers=4,
            pin_memory=False, sampler=target_sampler)
        self.target_data = enumerate(target_dataloader)

    def model_initial(self):
        # build backbone
        rank = dist.get_rank()
        self.backbone = ResNet().cuda(rank)
        self.backbone = DDP(self.backbone, device_ids=[rank])
        
        # restore_part
        # The parameters are restored successfully here (I have checked it).
        self.backbone.train()

        # classifier
        self.classifier = Classifier(...).cuda(rank)
        self.classifier = DDP(self.classifier, device_ids=[rank])
        self.classifier.train()

        # optimizer
        self.backbone_optimizer = optim.SGD(self.backbone.parameters(), ...)
        self.backbone_optimizer.zero_grad()

        self.classifier_optimizer = optim.Adam(self.classifier.parameters(), ...)
        self.classifier_optimizer.zero_grad()

    def train(self):
        # pytorch prework
        rank = dist.get_rank()
        self.criterion = torch.nn.BCEWithLogitsLoss().cuda(rank)
        
        for i in range(1, self.config.num_steps):
            # get data
            c, batch = next(self.source_data)
            image, label, _, _ = batch
            _, batch = next(self.target_data)
            image_t, label_t, _, _ = batch

            self.step(i, image, label, image_t, label_t, loss_dic)

            gc.collect()

    def step(self, i, image, label, image_t, label_t, loss_dic):
        rank = dist.get_rank()

        self.backbone_optimizer.zero_grad()
        self.classifier_optimizer.zero_grad()

        # supervised learning for source
        image = Variable(image).cuda(rank)
        x = self.backbone(image)
        y1, _ = self.classifier(x)
        loss = self.criterion(y1, label.long().cuda(rank))
        loss.backward()
      

        for para in self.classifier.parameters():
            para.requires_grad = False
        image_t = Variable(image_t).cuda(rank)
        x = self.backbone(image_t)
        _, y2 = self.classifier(x)

        label2 = Variable(...).cuda(rank)
        loss = self.criterion(y2, label2)
        loss.backward()

        ###
        # optimize the parameter
        self.backbone_optimizer.step()
        self.classifier_optimizer.step()

        ###
        # recycle variable
        # delete some intermediate variables
        del x .....
        torch.cuda.empty_cache()
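With `enumerate(dataloader)`, calling `next()` raises `StopIteration` once the loader has gone through all of its sampler's indices, and `DistributedSampler` only reshuffles when `set_epoch()` is called. A minimal sketch of a step-based data feed that handles both, using a hypothetical helper named `infinite_batches` (not part of PyTorch):

```python
from typing import Iterator

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def infinite_batches(loader: DataLoader) -> Iterator:
    """Yield batches from `loader` forever (hypothetical helper).

    Starts a new pass whenever the loader is exhausted and, if the loader
    uses a DistributedSampler, calls set_epoch() first so that every pass is
    reshuffled with the same permutation on all ranks.
    """
    if len(loader) == 0:
        raise ValueError("loader yields no batches")
    epoch = 0
    while True:
        if isinstance(loader.sampler, DistributedSampler):
            loader.sampler.set_epoch(epoch)
        for batch in loader:
            yield batch
        epoch += 1


if __name__ == "__main__":
    # Single-process demo: num_replicas and rank are passed explicitly,
    # so no process group is required.
    ds = TensorDataset(torch.arange(10))
    sampler = DistributedSampler(ds, num_replicas=1, rank=0, shuffle=True, seed=1234)
    loader = DataLoader(ds, batch_size=4, sampler=sampler)
    batches = infinite_batches(loader)
    for step in range(6):  # 10 samples, batch_size 4 -> 3 batches per pass
        (x,) = next(batches)
        print(step, sampler.epoch, x.tolist())
```

In `data_initial` this would replace `enumerate(...)`, e.g. `self.source_data = infinite_batches(source_dataloader)`, and `train()` would then unpack `image, label, _, _ = next(self.source_data)` directly, since the helper yields batches without an index.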

The main function that launches distributed training is shown below.

def main():
    world_size = torch.cuda.device_count()
    # args must be a tuple, hence the trailing comma
    mp.spawn(sub_process, args=(world_size,), nprocs=world_size, join=True)

def sub_process(rank, world_size):
    set_up(rank, world_size)
    trainer = Trainer()
    cleanup()

def set_up(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.113'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size) 

def cleanup():
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
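Separately from the DDP question, a commonly recommended precaution is to call `torch.cuda.set_device(rank)` at the start of each spawned process. Whether it affects this particular OOM is not established here; the sketch below only shows where the call would go. The `backend`, `addr` and `port` parameters are additions for illustration (their defaults are the values hard-coded in the script above).

```python
import os

import torch
import torch.distributed as dist


def set_up(rank: int, world_size: int, backend: str = "nccl",
           addr: str = "127.0.0.113", port: str = "12355") -> None:
    """Variant of set_up() that pins each process to its own GPU first."""
    if torch.cuda.is_available():
        # After this, CUDA calls that use the "current device" instead of an
        # explicit index (e.g. torch.cuda.empty_cache()) target this rank's
        # GPU rather than the default GPU 0.
        torch.cuda.set_device(rank)
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = port
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup() -> None:
    dist.destroy_process_group()
```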

In process 0 (rank 0), the first training step runs fine, but the second step fails with OOM:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/xx/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
.......
  File "/home/xx/code/xxx.py", line 225, in step
    loss_aux = self.seg_criterion(out_aux, label.long().cuda(rank))
  File "/home/xx/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xx/code/multitask/utils/utils.py", line 44, in forward
    predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
RuntimeError: CUDA out of memory. Tried to allocate 536.00 MiB (GPU 0; 10.73 GiB total capacity; 8.30 GiB already allocated; 198.56 MiB free; 9.01 GiB reserved in total by PyTorch)

I suspect a memory leak, but I have no experience with distributed training. How should I solve this problem?

Does anyone know how to fix this? It’s a little urgent. Again, in short: how do I fix the memory leak on rank 0?

Hey @JIE_LIU, which version of PyTorch are you using? If it is v1.6, I suspect DDP’s communication-bucket reconstruction algorithm temporarily boosts memory consumption, which then triggers the OOM.

cc @Yanli_Zhao

Sure, it’s 1.6.0. How can I avoid this?

I don’t think there is a clean way to avoid it in v1.6. One hacky workaround might be to introduce a tiny unused parameter in the model (e.g., self.unused = nn.Linear(1, 1)) and then set find_unused_parameters=True in the DDP constructor, which would disable bucket rebuilding.
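The workaround above can be sketched as follows, assuming the model is already on its target device and a process group has been initialized. The helper name `ddp_with_unused_param` and the attribute name `unused` are hypothetical; the DDP-specific piece is the `find_unused_parameters=True` constructor argument.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_with_unused_param(model: nn.Module, device_ids=None) -> DDP:
    """Sketch of the v1.6 workaround described above (hypothetical helper).

    Adds a tiny parameter that forward() never touches, then enables
    find_unused_parameters so DDP expects some parameters to get no gradient.
    Requires an initialized process group (dist.init_process_group).
    """
    # Put the extra layer on the same device as the existing parameters,
    # since DDP expects the module's parameters to live on one device.
    device = next(model.parameters()).device
    model.unused = nn.Linear(1, 1).to(device)  # registered, never called
    return DDP(model, device_ids=device_ids, find_unused_parameters=True)
```

In the trainer this would look like `self.backbone = ddp_with_unused_param(ResNet().cuda(rank), device_ids=[rank])`. Two side effects to keep in mind: `find_unused_parameters=True` makes DDP traverse the autograd graph after every forward pass, which adds some overhead, and the extra `unused.weight`/`unused.bias` entries appear in `state_dict()`, so loading a checkpoint saved without them would need `load_state_dict(..., strict=False)`.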

This looks like a regression to me. @Yanli_Zhao has some recent work to reduce DDP’s memory footprint, and hopefully that can help.

BTW, if possible, could you try the same script with PyTorch v1.5? That would help confirm whether bucket reconstruction is indeed the culprit.
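To compare runs under v1.5 and v1.6, one option is to log each rank's allocator statistics every step. A minimal sketch using `torch.cuda.memory_allocated`, `torch.cuda.max_memory_allocated` and `torch.cuda.memory_reserved`; the helper name `memory_report` is hypothetical, and these numbers only cover memory managed by PyTorch's caching allocator, not the CUDA context itself.

```python
import torch


def memory_report(rank: int, step: int) -> str:
    """One-line summary of this rank's CUDA memory (hypothetical helper)."""
    prefix = f"[torch {torch.__version__}] rank {rank} step {step}:"
    if not torch.cuda.is_available():
        return f"{prefix} CUDA not available"
    device = torch.device("cuda", rank)
    mib = 1024 ** 2
    allocated = torch.cuda.memory_allocated(device) / mib
    peak = torch.cuda.max_memory_allocated(device) / mib
    reserved = torch.cuda.memory_reserved(device) / mib
    return (f"{prefix} allocated {allocated:.0f} MiB, "
            f"peak {peak:.0f} MiB, reserved {reserved:.0f} MiB")
```

Printing `memory_report(rank, i)` after each `self.step(...)` call in `train()` would at least show on which rank and at which step the extra memory appears, under each PyTorch version.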

Thanks for the suggestion. I’ll try it later and report back.
