Extra memory load on rank 0 while using DDP, not cleared after validation

Hello,
I’m running a PyTorch DDP job across multiple GPUs, and my training loop includes a validation phase after each epoch. The validation completes successfully, but right when I return to training on the next epoch, I get a CUDA OOM error on rank 0. A memory check shows that rank 0 is still holding around 40GB of GPU memory, whereas the other GPUs have freed most of their memory after validation.

I tried clearing caches (torch.cuda.empty_cache()) and calling gc.collect() at the end of validation, but it still appears that rank 0 retains a large portion of memory. Has anyone experienced a similar issue or have suggestions on what could be causing rank 0 to hold onto so much memory after validation?

Some extra details:

  • I’m using torch 2.5.1+cu121 with distributed training on 8 H200 GPUs.
  • My validation loop uses with torch.no_grad(): and calls model.eval().
  • The memory usage spike appears specifically on rank 0; other ranks free up memory as expected.
  • If I skip validation, training runs fine, but rank 0 uses 80 GB of memory while the other cards use 40 GB each. So rank 0 carries 40 GB of extra memory during training, plus another 40 GB left over from validation.
  • The OOM error occurs when I attempt to resume training after validation finishes.

Any ideas on how to fix or debug this? Thanks in advance!

I don’t understand this point, as it seems that even without the validation run you are seeing 2x the memory usage on GPU0?

Yes, sorry for the unclear description. Even without the validation run, rank 0 uses extra memory.

Thanks for confirming! In that case I would recommend focusing on why training alone already shows this large increase in memory usage, since DDP is supposed to use the same amount of memory on all ranks (unless you, for example, increase the batch size on one rank, which would be a non-standard approach).
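
One way to start that debugging is to log, on every rank, both what the PyTorch caching allocator holds and how much of the device is in use overall. A large gap between the two hints that something outside this process's allocator (for example another process) occupies the GPU. This is only a minimal sketch: the helper names `format_report` and `log_cuda_memory` are made up for illustration, and it assumes one process per GPU with `device == rank`, as in a single-node job.

```python
def format_report(rank, allocated, reserved, free, total):
    """Format byte counts as a one-line, per-rank summary in MiB."""
    mib = 2**20
    used = total - free  # device-wide usage, including other processes
    return (
        f"rank {rank}: allocated={allocated / mib:.0f} MiB, "
        f"reserved={reserved / mib:.0f} MiB, "
        f"device used={used / mib:.0f} of {total / mib:.0f} MiB"
    )


def log_cuda_memory(rank, device):
    """Print allocator statistics and device-wide usage for one GPU."""
    # Imported here so the formatting helper above works without PyTorch.
    import torch

    # The device is passed explicitly so that the query does not fall back
    # to the current device, which is cuda:0 unless it has been changed.
    free, total = torch.cuda.mem_get_info(device)
    line = format_report(
        rank,
        torch.cuda.memory_allocated(device),
        torch.cuda.memory_reserved(device),
        free,
        total,
    )
    print(line, flush=True)
    return line
```

Calling `log_cuda_memory(rank, rank)` at the end of each epoch on every rank gives one line per GPU that can be compared directly across ranks.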

Thanks for the suggestions! I have tried to reduce the code, and I have got a minimal dummy example here:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

class SimpleDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    dataset = SimpleDataset()
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)

    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ddp_model.parameters(), lr=0.01)

    num_epochs = 10
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs.to(rank))
            loss = criterion(outputs, labels.to(rank))
            loss.backward()
            optimizer.step()
        print(f"Rank {rank}, Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

This puts about 4.5 GB of memory load on cuda:0 and 884 MB on each of the other cards. Any ideas about what could be wrong here?

Thank you for the code!
I cannot reproduce any issues and see this memory usage:

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   2432443      C   /usr/bin/python                             680MiB |
|    1   N/A  N/A   2432444      C   /usr/bin/python                             680MiB |
|    2   N/A  N/A   2432445      C   /usr/bin/python                             680MiB |
|    3   N/A  N/A   2432446      C   /usr/bin/python                             680MiB |
|    4   N/A  N/A   2432447      C   /usr/bin/python                             680MiB |
|    5   N/A  N/A   2432448      C   /usr/bin/python                             680MiB |
|    6   N/A  N/A   2432449      C   /usr/bin/python                             680MiB |
|    7   N/A  N/A   2432450      C   /usr/bin/python                             680MiB |
+---------------------------------------------------------------------------------------+

Thanks for the test run!
I found one mistake in my post: I am actually using ‘gloo’ as the backend, because ‘nccl’ raises an error. The adjusted code is here:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

class SimpleDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    dataset = SimpleDataset()
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)

    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ddp_model.parameters(), lr=0.01)

    num_epochs = 10
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs.to(rank))
            loss = criterion(outputs, labels.to(rank))
            loss.backward()
            optimizer.step()
        print(f"Rank {rank}, Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

Maybe using the gloo backend is itself the problem?
Thank you very much for the help!

Edit: I have noticed several extra processes of 520 MiB each, all on cuda:0:

Thu Mar 13 18:39:42 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H200                    Off |   00000000:53:00.0 Off |                    0 |
| N/A   29C    P0            120W /  700W |    4345MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H200                    Off |   00000000:64:00.0 Off |                    0 |
| N/A   31C    P0            118W /  700W |     680MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H200                    Off |   00000000:75:00.0 Off |                    0 |
| N/A   31C    P0            120W /  700W |     680MiB / 143771MiB |      1%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H200                    Off |   00000000:86:00.0 Off |                    0 |
| N/A   32C    P0            121W /  700W |     680MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H200                    Off |   00000000:97:00.0 Off |                    0 |
| N/A   32C    P0            121W /  700W |     680MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H200                    Off |   00000000:A8:00.0 Off |                    0 |
| N/A   30C    P0            119W /  700W |     680MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H200                    Off |   00000000:B9:00.0 Off |                    0 |
| N/A   31C    P0            117W /  700W |     680MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H200                    Off |   00000000:CA:00.0 Off |                    0 |
| N/A   31C    P0            120W /  700W |     680MiB / 143771MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     15102      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    0   N/A  N/A     15103      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15104      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15105      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15106      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15107      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15108      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15109      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    1   N/A  N/A     15103      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    2   N/A  N/A     15104      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    3   N/A  N/A     15105      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    4   N/A  N/A     15106      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    5   N/A  N/A     15107      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    6   N/A  N/A     15108      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    7   N/A  N/A     15109      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
+-----------------------------------------------------------------------------------------+
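
The pattern in this table (the same PID listed on its own GPU and again on GPU 0) can also be picked out with a small script rather than by eye. The following is a rough sketch that assumes the process-table layout shown above; the names `gpus_by_pid` and `multi_gpu_pids` are hypothetical helpers, and the sample contains only a few rows copied from the table.

```python
import re
from collections import defaultdict

# Matches process rows of the form:
# | <gpu> <GI> <CI> <pid> <type> <process name> <N>MiB |
ROW = re.compile(r"^\|\s+(\d+)\s+\S+\s+\S+\s+(\d+)\s+\S+\s+.*?(\d+)MiB\s+\|$")


def gpus_by_pid(table):
    """Map each PID to the list of (gpu_index, MiB) entries it appears in."""
    usage = defaultdict(list)
    for line in table.splitlines():
        match = ROW.match(line.strip())
        if match:
            gpu, pid, mib = (int(group) for group in match.groups())
            usage[pid].append((gpu, mib))
    return dict(usage)


def multi_gpu_pids(table):
    """Return only the PIDs that hold memory on more than one GPU."""
    return {
        pid: entries
        for pid, entries in gpus_by_pid(table).items()
        if len({gpu for gpu, _ in entries}) > 1
    }


SAMPLE = """\
|    0   N/A  N/A     15102      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    0   N/A  N/A     15103      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    0   N/A  N/A     15104      C   ...niconda/envs/pytorch_env/bin/python        520MiB |
|    1   N/A  N/A     15103      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
|    2   N/A  N/A     15104      C   ...niconda/envs/pytorch_env/bin/python        672MiB |
"""

if __name__ == "__main__":
    for pid, entries in multi_gpu_pids(SAMPLE).items():
        print(pid, entries)
```

Any PID reported here has a CUDA context on more than one device, which in a one-process-per-GPU job usually means that process touched a GPU other than its own.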

Adding the following statement before “setup” resolved the issue:

torch.cuda.set_device(rank)
setup(rank, world_size)