Memory footprint for single-node multi-GPU setting (using DistributedDataParallel)

Hello, I have the following minimal working code (using DistributedDataParallel):

import logging
import os

import gpustat
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

logger = logging.getLogger(__name__)

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(10000, 2000)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(2000, 100)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

def setup_and_train(rank, num_gpus):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group("nccl", rank=rank, world_size=num_gpus)
    device = rank
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s.%(msecs)03d %(levelname)s Rank-{rank} %(module)s - %(pathname)s:%(lineno)d: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # setup agent model
    logger.info(gpustat.new_query())
    tiny_model = TinyModel().to(device)
    logger.info(device)
    logger.info(gpustat.new_query())

    tiny_model = DistributedDataParallel(tiny_model, device_ids=[rank], output_device=rank)

    logger.info(gpustat.new_query())

    try:
        while True:
            pass
    except KeyboardInterrupt:
        logger.info("Received SIGTERM.")
    if num_gpus > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    mp.spawn(setup_and_train, nprocs=num_gpus, args=(num_gpus,), join=True)

where I want to simulate single-node multi-GPU training, but I run into a memory problem with this code. I am using an instance with 8 GPUs (16 GB of memory each), and when I run this simple code (without any training or data manipulation), nvidia-smi reports the following:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    124274      C   ...vs/pytorch_p38/bin/python     1851MiB |
|    0   N/A  N/A    124275      C   ...vs/pytorch_p38/bin/python     1245MiB |
|    0   N/A  N/A    124276      C   ...vs/pytorch_p38/bin/python     1245MiB |
|    0   N/A  N/A    124277      C   ...vs/pytorch_p38/bin/python     1243MiB |
|    0   N/A  N/A    124278      C   ...vs/pytorch_p38/bin/python     1243MiB |
|    0   N/A  N/A    124279      C   ...vs/pytorch_p38/bin/python     1243MiB |
|    0   N/A  N/A    124280      C   ...vs/pytorch_p38/bin/python     1245MiB |
|    0   N/A  N/A    124281      C   ...vs/pytorch_p38/bin/python     1245MiB |
|    1   N/A  N/A    124275      C   ...vs/pytorch_p38/bin/python     1879MiB |
|    2   N/A  N/A    124276      C   ...vs/pytorch_p38/bin/python     1783MiB |
|    3   N/A  N/A    124277      C   ...vs/pytorch_p38/bin/python     1785MiB |
|    4   N/A  N/A    124278      C   ...vs/pytorch_p38/bin/python     1857MiB |
|    5   N/A  N/A    124279      C   ...vs/pytorch_p38/bin/python     1873MiB |
|    6   N/A  N/A    124280      C   ...vs/pytorch_p38/bin/python     1827MiB |
|    7   N/A  N/A    124281      C   ...vs/pytorch_p38/bin/python     1803MiB |
+-----------------------------------------------------------------------------+

I can see that GPU 0 also carries a memory footprint from all the other GPUs' processes (the PIDs of all eight workers show up on GPU 0). The problem is that as soon as I start loading data, the first GPU crashes with an out-of-memory error, and I then need a very small batch size, which negates the performance gain. Is there something wrong with my usage of DistributedDataParallel? Is it a bug? Or, in the worst case, is this behavior normal?
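
For reference, this is roughly how I compare what the PyTorch caching allocator reports on each rank with what nvidia-smi shows (a rough check of my own, not part of the snippet above; log_allocator_stats is just a hypothetical helper reusing the logger and torch import from the snippet):

def log_allocator_stats(rank):
    # Allocator statistics for this rank's GPU only. They cover tensors held
    # by the caching allocator and do not include the CUDA context itself,
    # which is part of the per-process usage that nvidia-smi reports.
    allocated_mib = torch.cuda.memory_allocated(rank) / 1024 ** 2
    reserved_mib = torch.cuda.memory_reserved(rank) / 1024 ** 2
    logger.info(f"rank {rank}: allocated {allocated_mib:.1f} MiB, reserved {reserved_mib:.1f} MiB")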

In case you are using CUDA 12, you might need to cherry-pick this PR for your source build.

No, it is:

NVIDIA-SMI 450.142.00   Driver Version: 450.142.00   CUDA Version: 11.0  

but I will check that anyway.

In that case your code might have an issue and could be explicitly creating unneeded CUDA contexts on the default device, which is cuda:0.
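
For illustration, a common pattern is to pin each worker to its own GPU before any other CUDA call is made, so that any lazily created context lands on cuda:rank instead of cuda:0. A minimal sketch based on your snippet (not a verified fix for your larger code base):

def setup_and_train(rank, num_gpus):
    # Make this rank's GPU the current device before anything else touches CUDA.
    torch.cuda.set_device(rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group("nccl", rank=rank, world_size=num_gpus)
    model = TinyModel().to(rank)
    model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    # ... training loop as before ...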

But the code shown in the snippet above is everything I need to reproduce this behavior. No tensor is explicitly allocated on any device; the only code used is taken from the examples of how to use DistributedDataParallel in PyTorch.

I cannot reproduce the issue using your code and see the expected single CUDA context per device:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   2262945      C   /usr/bin/python                  1342MiB |
|    1   N/A  N/A   2262946      C   /usr/bin/python                  1564MiB |
|    2   N/A  N/A   2262947      C   /usr/bin/python                  1564MiB |
|    3   N/A  N/A   2262948      C   /usr/bin/python                  1564MiB |
|    4   N/A  N/A   2262949      C   /usr/bin/python                  1564MiB |
|    5   N/A  N/A   2262950      C   /usr/bin/python                  1564MiB |
|    6   N/A  N/A   2262951      C   /usr/bin/python                  1564MiB |
|    7   N/A  N/A   2262952      C   /usr/bin/python                  1420MiB |
+-----------------------------------------------------------------------------+

I found the original problem - an outdated library (sorry for that). I was using torch==1.10.0, and after updating to 1.12.1 the memory problem disappeared. But I kept digging, because even after the update I still saw this behavior in my original code (not this snippet, but a bigger version), and I found that while trying to solve the problem I had added

torch.cuda.empty_cache()

at some point in the code. When I removed it, the behavior went back to normal. So I can reproduce that behavior, but I am not sure whether you still want to investigate it:

import logging
import os

import gpustat
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

logger = logging.getLogger(__name__)

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(10000, 2000)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(2000, 100)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

def setup_and_train(rank, num_gpus):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group("nccl", rank=rank, world_size=num_gpus)
    device = rank
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s.%(msecs)03d %(levelname)s Rank-{rank} %(module)s - %(pathname)s:%(lineno)d: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # setup agent model
    logger.info(gpustat.new_query())
    tiny_model = TinyModel().to(device)
    logger.info(device)
    logger.info(gpustat.new_query())

    tiny_model = DistributedDataParallel(tiny_model, device_ids=[rank], output_device=rank)
    torch.cuda.empty_cache()
    logger.info(gpustat.new_query())

    try:
        while True:
            pass
    except KeyboardInterrupt:
        logger.info("Received SIGTERM.")
    if num_gpus > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    mp.spawn(setup_and_train, nprocs=num_gpus, args=(num_gpus,), join=True)

Thanks for the update.
I still cannot reproduce the issue using your updated code snippet with empty_cache:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   2656619      C   /usr/bin/python                  1342MiB |
|    1   N/A  N/A   2656620      C   /usr/bin/python                  1486MiB |
|    2   N/A  N/A   2656621      C   /usr/bin/python                  1486MiB |
|    3   N/A  N/A   2656622      C   /usr/bin/python                  1486MiB |
|    4   N/A  N/A   2656623      C   /usr/bin/python                  1486MiB |
|    5   N/A  N/A   2656624      C   /usr/bin/python                  1486MiB |
|    6   N/A  N/A   2656625      C   /usr/bin/python                  1486MiB |
|    7   N/A  N/A   2656626      C   /usr/bin/python                  1342MiB |
+-----------------------------------------------------------------------------+
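
In any case, if you want to keep empty_cache() in the larger code base, one option is to pin the process to its rank's GPU and scope the call explicitly, so that any lazy CUDA initialization happens on that device rather than on the default cuda:0. A sketch (scoped_empty_cache is just a hypothetical helper; I have not verified this against torch==1.10.0):

import torch

def scoped_empty_cache(rank):
    # Run empty_cache with this rank's GPU as the current device, so a lazily
    # created CUDA context (if any) ends up on cuda:rank instead of cuda:0.
    with torch.cuda.device(rank):
        torch.cuda.empty_cache()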