Distributed training: all processes reserve memory on GPU 0

For demonstration, here is a case with only 2 processes. You can see the effect in nvidia-smi:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:61:00.0 Off |                    0 |
| N/A   39C    P0    57W / 300W |   3433MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:62:00.0 Off |                    0 |
| N/A   36C    P0    58W / 300W |   2615MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    231620      C   ...0-torch2.1/bin/python3.10     2612MiB |
|    0   N/A  N/A    231621      C   ...0-torch2.1/bin/python3.10      818MiB |
|    1   N/A  N/A    231621      C   ...0-torch2.1/bin/python3.10     2612MiB |
+-----------------------------------------------------------------------------+

You can see that process 231621 (rank 1, which runs on GPU 1) also reserved some memory on GPU 0.
In my actual use case, I wanted to use 16 GPUs (a DGX with V100s). Every worker reserved some memory on GPU 0, which then caused an OOM, because not much memory was left anymore:
OutOfMemoryError: CUDA out of memory. Tried to allocate 52.00 MiB. GPU 0 has a total capacty of 31.74 GiB of which 39.62 MiB is free. Process 101553 has 1.19 GiB memory in use. Process 101552 has 1.19 GiB memory in use. Process 101563 has 1.19 GiB memory in use. Process 101549 has 1.19 GiB memory in use. Process 101555 has 1.19 GiB memory in use. Process 101558 has 1.19 GiB memory in use. Process 101557 has 1.19 GiB memory in use. Process 101562 has 1.19 GiB memory in use. Process 101556 has 1.19 GiB memory in use. Process 101560 has 1.19 GiB memory in use. Process 101554 has 1.19 GiB memory in use. Process 101550 has 1.19 GiB memory in use. Process 101561 has 1.19 GiB memory in use. Process 101559 has 1.19 GiB memory in use. Process 101551 has 1.19 GiB memory in use. Including non-PyTorch memory, this process has 13.82 GiB memory in use. Of the allocated memory 11.54 GiB is allocated by PyTorch, and 677.26 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
So, out of the 32 GB this GPU has (V100), 15 * 1.19 = 17.85 GB is reserved by the other processes, so only about 14 GB is left, which is not enough for my actual use case.

As you can see, this is obviously a problem. Is this expected behavior? Am I doing something wrong?

More debugging details here: Torch distributed: Every worker reserves memory on GPU 0 · Issue #1469 · rwth-i6/returnn · GitHub

I actually traced back where this memory reservation happens: it happens in DistributedDataParallel.__init__, inside _verify_param_shape_across_processes. Before that call, there is no memory reserved on GPU 0 by rank 1; after that call, I see the memory in nvidia-smi.
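For reference, such a check can also be done in-process via NVML instead of grepping nvidia-smi; this is just a sketch, assuming the nvidia-ml-py (pynvml) package is installed (querying NVML does not itself create a CUDA context on GPU 0):

# Sketch (assumes pynvml / nvidia-ml-py): report how much memory the current
# process has allocated on GPU 0, according to the driver.
import os
import pynvml


def this_process_mem_on_gpu0() -> int:
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0
        for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
            if proc.pid == os.getpid():
                return proc.usedGpuMemory or 0  # bytes; can be None on some drivers
        return 0
    finally:
        pynvml.nvmlShutdown()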

Your processes are all initializing a CUDA context on the default device. This can be avoided by setting the device via torch.cuda.set_device. If that doesn't help and you get stuck, could you post a minimal and executable code snippet reproducing the issue?
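A minimal sketch of that suggestion (an illustration, not the original script; it assumes the usual torchrun environment variables):

# Sketch only: pin each worker to its own GPU before any CUDA work or collective,
# so that the CUDA context is created on that GPU and not on the default device (GPU 0).
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])  # set by torch.distributed.run / torchrun
torch.cuda.set_device(local_rank)  # make cuda:<local_rank> the current device for this process
dist.init_process_group(backend="nccl")

x = torch.ones(1, device=torch.device("cuda", local_rank))
dist.all_reduce(x)  # ranks > 0 should now not show up on GPU 0 in nvidia-smi
print(f"rank {dist.get_rank()}: {x.item()}")
dist.destroy_process_group()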

Here is some example code (online: https://github.com/albertz/playground/blob/master/torch-distributed-demo.py):

"""
run::
    python -m torch.distributed.run --standalone --nnodes 1 --nproc-per-node=2 torch-distributed-demo.py

https://pytorch.org/docs/stable/notes/ddp.html
"""

import os
import sys
import time
import subprocess as sp
import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


# Print all nvidia-smi lines mentioning this process (only on local rank 1),
# to see on which GPUs this process has memory reserved.
def _debug_mem(msg):
    if local_rank == 1:
        print(f"*** {msg} {{")
        sp.call(
            f"(nvidia-smi; echo '*** {msg} -- {os.getpid()} }}'; ) | grep {os.getpid()}",
            shell=True,
            stdout=sys.stdout,
            stderr=sys.stdout,
        )
        sys.stdout.flush()


dist.init_process_group(backend=None)  # backend=None: use the defaults, i.e. NCCL for CUDA and Gloo for CPU
local_rank = int(os.environ["LOCAL_RANK"])
local_size = int(os.environ["LOCAL_WORLD_SIZE"])
dev = torch.device(f"cuda:{local_rank}")
print(f"Start running torch distributed training on local rank {local_rank}/{local_size}.")
_debug_mem("start")


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)


model = Model()
model.to(dev)
_debug_mem("after model init")

ddp_model = DistributedDataParallel(model, device_ids=[local_rank])
_debug_mem("after DDP wrapping")

# define loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
_debug_mem("after optimizer init")

step = 0
while True:
    # forward pass
    outputs = ddp_model(torch.randn(20, 10, device=dev))
    labels = torch.randn(20, 10, device=dev)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

    print(f"[{local_rank}] step {step}")
    _debug_mem(f"step {step}")

    if step >= 3:
        break
    time.sleep(0.5)
    step += 1

I run this via: python -m torch.distributed.run --standalone --nnodes 1 --nproc-per-node=2 torch-distributed-demo.py.

The environment:

Wed Nov 29 17:26:38 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:61:00.0 Off |                    0 |
| N/A   38C    P0    57W / 300W |   2953MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:62:00.0 Off |                    0 |
| N/A   35C    P0    57W / 300W |   2135MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Output (with NCCL_DEBUG=INFO, accidentally also with NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1, but that has no influence here):

Start running torch distributed training on local rank 0/2.                                                                                    
Start running torch distributed training on local rank 1/2.                                                                                    
*** start {                                                                                                                                    
*** start -- 62563 }                                                                                                                           
*** after model init {                                                                                                                         
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      752MiB |
ncg12:62562:62562 [0] NCCL INFO Bootstrap : Using ib0:192.168.8.11<0>  
*** after model init -- 62563 } 
ncg12:62562:62562 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
ncg12:62562:62562 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
ncg12:62562:62562 [0] NCCL INFO cudaDriverVersion 12000
NCCL version 2.18.1+cuda12.1
ncg12:62562:62597 [0] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
ncg12:62562:62597 [0] NCCL INFO NET/Socket : Using [0]ib0:192.168.8.11<0>
ncg12:62562:62597 [0] NCCL INFO Using network Socket
ncg12:62563:62563 [1] NCCL INFO cudaDriverVersion 12000
ncg12:62563:62563 [1] NCCL INFO Bootstrap : Using ib0:192.168.8.11<0>
ncg12:62563:62563 [1] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
ncg12:62563:62563 [1] NCCL INFO NET/Plugin : No plugin found, using internal implementation
ncg12:62563:62598 [1] NCCL INFO NCCL_IB_DISABLE set by environment to 1.
ncg12:62563:62598 [1] NCCL INFO NET/Socket : Using [0]ib0:192.168.8.11<0>
ncg12:62563:62598 [1] NCCL INFO Using network Socket
ncg12:62563:62598 [1] NCCL INFO NCCL_P2P_LEVEL set by environment to LOC
ncg12:62562:62597 [0] NCCL INFO NCCL_P2P_LEVEL set by environment to LOC
ncg12:62562:62597 [0] NCCL INFO Channel 00/02 :    0   1
ncg12:62562:62597 [0] NCCL INFO Channel 01/02 :    0   1
ncg12:62562:62597 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1                                                                    
ncg12:62562:62597 [0] NCCL INFO P2P Chunksize set to 131072
ncg12:62563:62598 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0                                                                    
ncg12:62563:62598 [1] NCCL INFO P2P Chunksize set to 131072                                                                                    
ncg12:62563:62598 [1] NCCL INFO Channel 00 : 1[62000] -> 0[61000] via SHM/direct/direct                                                        
ncg12:62563:62598 [1] NCCL INFO Channel 01 : 1[62000] -> 0[61000] via SHM/direct/direct                                                        
ncg12:62562:62597 [0] NCCL INFO Channel 00 : 0[61000] -> 1[62000] via SHM/direct/direct                                                        
ncg12:62562:62597 [0] NCCL INFO Channel 01 : 0[61000] -> 1[62000] via SHM/direct/direct                                                        
ncg12:62562:62597 [0] NCCL INFO Connected all rings                                                                                            
ncg12:62562:62597 [0] NCCL INFO Connected all trees                                                                                            
ncg12:62562:62597 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512                                                                  
ncg12:62562:62597 [0] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer                                      
ncg12:62563:62598 [1] NCCL INFO Connected all rings                                                                                            
ncg12:62563:62598 [1] NCCL INFO Connected all trees                                                                                            
ncg12:62563:62598 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512                                                                  
ncg12:62563:62598 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer                                      
ncg12:62562:62597 [0] NCCL INFO comm 0x39b60860 rank 0 nranks 2 cudaDev 0 busId 61000 commId 0x837bd98b46a5d74c - Init COMPLETE                
ncg12:62563:62598 [1] NCCL INFO comm 0x50da4860 rank 1 nranks 2 cudaDev 1 busId 62000 commId 0x837bd98b46a5d74c - Init COMPLETE                
*** after DDP wrapping {                                                                                                                        
|    0   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      818MiB |                                                                
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      846MiB |                                                                
*** after DDP wrapping -- 62563 }                                                                                                              
*** after optimizer init {                                                                                                                     
[0] step 0                                                                                                                                     
|    0   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      818MiB |                                                                
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      846MiB |                                                                
*** after optimizer init -- 62563 }                                            
[1] step 0                                                           
*** step 0 {                    
[0] step 1                                                                                                                                      
|    0   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      966MiB |
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10     1030MiB |            
*** step 0 -- 62563 }
[1] step 1                                                                                                                                      
*** step 1 {                                                                                                                                    
[0] step 2                                                                                                                                      
|    0   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      966MiB |                                                                 
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10     1030MiB |                                                                
*** step 1 -- 62563 }                                      
[1] step 2                                                                                                                                     
*** step 2 {                                                                                                                                   
[0] step 3                                                                                                                                     
|    0   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      966MiB |                                                                
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10     1030MiB |                                                                
*** step 2 -- 62563 }                                                                                                                          
[1] step 3                                                                                                                                     
*** step 3 {                                                                                                                                   
|    0   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10      966MiB |                                                                
|    1   N/A  N/A     62563      C   ...0-torch2.1/bin/python3.10     1030MiB |                                                                
*** step 3 -- 62563 }                            

Oh, adding torch.cuda.set_device really fixed the problem! But this looks like a bug, right? You can see that the model is already on the right device, and then this happens inside the DDP wrapper; more specifically, the memory on the wrong device gets reserved inside _verify_param_shape_across_processes.

Btw, torch.cuda.max_memory_reserved(0) returned 0 at all times. This also seems wrong?
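One possible explanation (an assumption, not confirmed here): the allocator statistics only cover memory managed by PyTorch's caching allocator in the current process, while the stray memory on GPU 0 would just be a CUDA context plus NCCL buffers, which the allocator never sees. A sketch comparing the allocator statistics with the driver-level view:

# Sketch: PyTorch caching-allocator statistics vs. driver-level memory info.
# The allocator statistics only track memory that the caching allocator of the
# current process manages on that device; a bare CUDA context would not appear here.
import torch

dev = 0  # note: querying device 0 from another rank may itself initialize CUDA on it
print("reserved by caching allocator:", torch.cuda.memory_reserved(dev))
print("peak reserved by caching allocator:", torch.cuda.max_memory_reserved(dev))
free, total = torch.cuda.mem_get_info(dev)  # cudaMemGetInfo, i.e. across all processes
print(f"driver-level: {free} bytes free of {total}")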