Hello, I have this minimal working example (using DistributedDataParallel):
import logging
import os

import gpustat
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import transformers
from torch.nn.parallel import DistributedDataParallel

logger = logging.getLogger(__name__)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.linear1 = torch.nn.Linear(10000, 2000)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(2000, 100)
        self.softmax = torch.nn.Softmax(dim=1)  # explicit dim avoids the deprecation warning

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


def setup_and_train(rank, num_gpus):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group("nccl", rank=rank, world_size=num_gpus)
    device = rank
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s.%(msecs)03d %(levelname)s Rank-{rank} %(module)s - %(pathname)s:%(lineno)d: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # set up the model, logging GPU memory before and after each step
    logger.info(gpustat.new_query())
    tiny_model = TinyModel().to(device)
    logger.info(device)
    logger.info(gpustat.new_query())
    tiny_model = DistributedDataParallel(tiny_model, device_ids=[rank], output_device=rank)
    logger.info(gpustat.new_query())
    try:
        # keep the processes alive so nvidia-smi can be inspected
        while True:
            pass
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt.")
    dist.destroy_process_group()


if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    mp.spawn(setup_and_train, nprocs=num_gpus, args=(num_gpus,), join=True)
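To separate my own tensor allocations from whatever nvidia-smi is counting, I also log what PyTorch's caching allocator reports per rank (a small sketch; report_memory is just an illustrative helper name, not part of the repro above):

def report_memory(rank):
    # memory_allocated counts live tensors owned by this process on its device;
    # nvidia-smi additionally shows memory the allocator does not see
    # (e.g. the CUDA context itself).
    allocated = torch.cuda.memory_allocated(rank) / 2**20
    reserved = torch.cuda.memory_reserved(rank) / 2**20
    logger.info(f"rank {rank}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")

This reports only a few hundred MiB of tensors per rank, far less than what nvidia-smi attributes to each process.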
I want to simulate single-node multi-GPU training, but I run into a memory problem with this code. I am using an instance with 8 GPUs (16 GB of memory each), and when I run this simple code (without any training or data manipulation), nvidia-smi gives the following report:
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 124274 C ...vs/pytorch_p38/bin/python 1851MiB |
| 0 N/A N/A 124275 C ...vs/pytorch_p38/bin/python 1245MiB |
| 0 N/A N/A 124276 C ...vs/pytorch_p38/bin/python 1245MiB |
| 0 N/A N/A 124277 C ...vs/pytorch_p38/bin/python 1243MiB |
| 0 N/A N/A 124278 C ...vs/pytorch_p38/bin/python 1243MiB |
| 0 N/A N/A 124279 C ...vs/pytorch_p38/bin/python 1243MiB |
| 0 N/A N/A 124280 C ...vs/pytorch_p38/bin/python 1245MiB |
| 0 N/A N/A 124281 C ...vs/pytorch_p38/bin/python 1245MiB |
| 1 N/A N/A 124275 C ...vs/pytorch_p38/bin/python 1879MiB |
| 2 N/A N/A 124276 C ...vs/pytorch_p38/bin/python 1783MiB |
| 3 N/A N/A 124277 C ...vs/pytorch_p38/bin/python 1785MiB |
| 4 N/A N/A 124278 C ...vs/pytorch_p38/bin/python 1857MiB |
| 5 N/A N/A 124279 C ...vs/pytorch_p38/bin/python 1873MiB |
| 6 N/A N/A 124280 C ...vs/pytorch_p38/bin/python 1827MiB |
| 7 N/A N/A 124281 C ...vs/pytorch_p38/bin/python 1803MiB |
+-----------------------------------------------------------------------------+
I can see that GPU 0 also carries a memory footprint from all the other GPUs' workers (the PIDs of the other ranks appear on GPU 0 as well). The problem is that as soon as I start loading data, the first GPU crashes with an OutOfMemory error, so I have to use a very small batch size, which negates the performance boost. Is there something wrong with my usage of DistributedDataParallel? Or is it a bug? Or, in the worst case, is this behavior normal?
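For what it's worth, one variant I am considering (assuming the duplicate footprints come from each worker lazily initializing its CUDA context on GPU 0) is to pin every process to its own device before any CUDA or NCCL call:

def setup_and_train(rank, num_gpus):
    # Pin this process to its own GPU before init_process_group and before
    # any CUDA operation, so NCCL and lazy CUDA initialization do not
    # create an extra context on GPU 0.
    torch.cuda.set_device(rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    dist.init_process_group("nccl", rank=rank, world_size=num_gpus)
    tiny_model = TinyModel().to(rank)
    tiny_model = DistributedDataParallel(tiny_model, device_ids=[rank], output_device=rank)

I have not confirmed that this removes the extra allocations on GPU 0, so I would appreciate knowing whether calling torch.cuda.set_device(rank) like this is the intended usage.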