Hello,
I’m running a PyTorch DDP job across multiple GPUs, and my training loop includes a validation phase after each epoch. The validation completes successfully, but right when I return to training on the next epoch, I get a CUDA OOM error on rank 0. A memory check shows that rank 0 is still holding around 40GB of GPU memory, whereas the other GPUs have freed most of their memory after validation.
I tried clearing caches (torch.cuda.empty_cache()) and calling gc.collect() at the end of validation, but it still appears that rank 0 retains a large portion of memory. Has anyone experienced a similar issue or have suggestions on what could be causing rank 0 to hold onto so much memory after validation?
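For reference, the end of my validation phase looks roughly like this. It is a simplified sketch where model, val_loader, and rank stand in for my real objects, and the print at the end is just the memory check I mentioned:

import gc
import torch

def validate(model, val_loader, rank):
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs.to(rank))
            correct += (outputs.argmax(dim=1) == labels.to(rank)).sum().item()

    # Cleanup I tried at the end of validation (does not help on rank 0).
    gc.collect()
    torch.cuda.empty_cache()

    # Memory check: rank 0 still reports ~40 GB here, the other ranks are mostly free.
    print(f"rank {rank}: "
          f"allocated={torch.cuda.memory_allocated(rank) / 1e9:.1f} GB, "
          f"reserved={torch.cuda.memory_reserved(rank) / 1e9:.1f} GB")
    return correct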
Some extra details:
I’m using torch 2.5.1+cu121 and running distributed training on 8 H200 GPUs.
My validation loop uses with torch.no_grad(): and calls model.eval().
The memory usage spike appears specifically on rank 0; other ranks free up memory as expected.
If I skip validation, training progresses fine, but rank 0 uses 80 GB while the other cards use about 40 GB. So the OOM on rank 0 seems to come from roughly 40 GB left over from validation on top of the extra 40 GB that rank 0 already uses during training.
The OOM error occurs when I attempt to resume training after validation finishes.
Any ideas on how to fix or debug this? Thanks in advance!
Thanks for confirming! In this case I would recommend focusing on debugging why training is already showing this large increase in memory usage, since DDP is supposed to use the same amount of memory on all ranks (unless you increase e.g. the batch size on one rank, which would be a non-standard approach).
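A quick way to check would be to log the per-rank usage at a few points in the loop, e.g. right after an optimizer step and right after validation. Just a sketch; the helper name and the placement are up to you:

import torch

def log_memory(rank, tag):
    # allocated = memory held by live tensors; reserved = what the caching allocator keeps on this GPU
    alloc = torch.cuda.memory_allocated(rank) / 1e9
    reserved = torch.cuda.memory_reserved(rank) / 1e9
    print(f"[{tag}] rank {rank}: allocated={alloc:.1f} GB, reserved={reserved:.1f} GB")

# e.g. in your loop:
# log_memory(rank, f"epoch {epoch} after optimizer.step()")
# log_memory(rank, f"epoch {epoch} after validation")

If the numbers already diverge between ranks during plain training, torch.cuda.memory_summary(rank) gives a more detailed breakdown of what is allocated on that device.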
Thanks for the test run!
I found one mistake I made in my post: I am using ‘gloo’ as the backend instead of ‘nccl’, since ‘nccl’ raises an error on my setup. The adjusted code is here:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler


class SimpleDataset(Dataset):
    # Tiny random dataset, only used to reproduce the setup.
    def __init__(self, size=100):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Using gloo here because nccl raises an error on my setup.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    setup(rank, world_size)

    dataset = SimpleDataset()
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)

    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ddp_model.parameters(), lr=0.01)

    num_epochs = 10
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs.to(rank))
            loss = criterion(outputs, labels.to(rank))
            loss.backward()
            optimizer.step()
        print(f"Rank {rank}, Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

    cleanup()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)
Maybe using the gloo backend is a problem in itself?
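If I switch back to nccl, the only change would be in setup(), roughly like this; the torch.cuda.set_device call is an addition I have not tested yet (it was not in my original nccl attempt):

import os
import torch
import torch.distributed as dist

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Pin this process to its own GPU before init; otherwise anything created on the
    # default CUDA device ends up on GPU 0 for every process.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)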
Thank you very much for the help!