I am running into a problem when training a neural network with torch DistributedDataParallel: the system crashes and immediately reboots when I train on 4 GPUs. Interestingly, the problem does not occur when I train on 3 GPUs. The solutions offered in other posts on this forum about similar problems did not help either.
Here are some further details about my setup that make the problem even more puzzling:
- I created a fresh PyTorch environment with Python 3.11.3; the PyTorch versions I tried are 2.0.1 and 1.12.1+cu116. The same crash also occurred with older PyTorch and Python versions.
>>> print(torch.__version__)
2.0.1
>>> print(torch.__version__)
1.12.1+cu116
...
...
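In case it is relevant, this is how I can collect the remaining version details from the same environment (just a small sketch; torch.version.cuda and torch.cuda.nccl.version() report the CUDA and NCCL versions PyTorch was built with):

import torch

print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)    # CUDA toolkit version PyTorch was built against
print("NCCL:", torch.cuda.nccl.version())     # bundled NCCL version (tuple)
print("GPU count:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f"  cuda:{i} ->", torch.cuda.get_device_name(i))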
- I used the following torch DistributedDataParallel example code:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(784, 10000)
        self.fc2 = nn.Linear(10000, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.fc2(x)
        return x


def train(rank, num_gpus):
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=num_gpus, rank=rank
    )
    torch.cuda.set_device(rank)
    model = SimpleNet().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])
    print("Rank ", rank, ", Model Created")
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    train_set = datasets.MNIST("./data", download=True, train=True, transform=transform)
    train_sampler = DistributedSampler(
        dataset=train_set, num_replicas=num_gpus, rank=rank
    )
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=512,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
        sampler=train_sampler,
    )
    criterion = nn.CrossEntropyLoss().to(rank)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    for epoch in range(100):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.to(rank)
            labels = labels.to(rank)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print("Rank ", rank, ", Epoch ", epoch, ", Loss: ", running_loss)


def main():
    num_gpus = 4
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    mp.spawn(train, args=(num_gpus,), nprocs=num_gpus, join=True)


if __name__ == "__main__":
    main()
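Since the machine reboots before any error can be printed, one thing I can do is enable NCCL's own debug logging and write it to per-rank files so that at least some output may survive the reboot. This is only a sketch of what I would add to main() (NCCL_DEBUG, NCCL_DEBUG_SUBSYS and NCCL_DEBUG_FILE are standard NCCL environment variables; the log path is just an example):

def main():
    num_gpus = 4
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Verbose NCCL logging, one file per process (%h = hostname, %p = pid),
    # so the output is on disk even if the machine reboots mid-run.
    os.environ["NCCL_DEBUG"] = "INFO"
    os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,P2P,NET"
    os.environ["NCCL_DEBUG_FILE"] = "/tmp/nccl_debug.%h.%p.log"
    mp.spawn(train, args=(num_gpus,), nprocs=num_gpus, join=True)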
- The output of nvidia-smi can be seen in the attached image.
- The training runs smoothly on any three of the GPUs (the order does not matter: 0,1,2 or 1,2,3 or 0,2,3 all work). I can even fill the VRAM completely, after which, as expected, a CUDA out of memory error is thrown.
- If I train with four GPUs, everything works up to roughly 1.2 GB per GPU. Once that limit is exceeded, the computer crashes and reboots before any error message can be displayed.
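For clarity, this is the kind of device selection I mean for the three-GPU runs (just an illustration using CUDA_VISIBLE_DEVICES, set before any CUDA call, with num_gpus lowered to match; nothing else in the script changes):

# Example: expose only GPUs 1, 2 and 3; must be set before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
num_gpus = 3  # world_size and nprocs are lowered accordingly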
- If I increase the batch size on 4 GPUs to the point where the memory should be exhausted, I can watch all four GPUs fill their VRAM up to 12 GB until the CUDA out of memory error is thrown. This leads me to believe that the issue is related to GPU-to-GPU communication and the NCCL backend rather than to memory usage itself.
- If I switch from the “nccl” backend to the “gloo” backend, the program works flawlessly on all GPUs.
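As far as I understand, gloo does not rely on direct GPU peer-to-peer transfers the way NCCL does, so I also want to check what PyTorch reports about P2P access between the four cards. A minimal probe (just a sketch, not part of the training script):

import torch

n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i != j:
            ok = torch.cuda.can_device_access_peer(i, j)
            print(f"GPU {i} -> GPU {j}: peer access {'available' if ok else 'NOT available'}")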
- The result of my NCCL tests is as follows:
(base) ➜ nccl-tests git:(master) mpirun -np 1 ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 4
Invalid MIT-MAGIC-COOKIE-1 key
# nThread 1 nGpus 4 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
# Using devices
# Rank 0 Group 0 Pid 412103 on kong device 0 [0x21] NVIDIA GeForce RTX 3080 Ti
# Rank 1 Group 0 Pid 412103 on kong device 1 [0x22] NVIDIA GeForce RTX 3080 Ti
# Rank 2 Group 0 Pid 412103 on kong device 2 [0x41] NVIDIA GeForce RTX 3080 Ti
# Rank 3 Group 0 Pid 412103 on kong device 3 [0x43] NVIDIA GeForce RTX 3080 Ti
#
#                                        out-of-place                      in-place
#       size      count   type  redop  root    time  algbw  busbw  #wrong    time  algbw  busbw  #wrong
#        (B) (elements)                          (us) (GB/s) (GB/s)            (us) (GB/s) (GB/s)
8 2 float sum -1 4318.1 0.00 0.00 0 126.5 0.00 0.00 0
16 4 float sum -1 127.6 0.00 0.00 0 340.7 0.00 0.00 0
32 8 float sum -1 125.1 0.00 0.00 0 3120.3 0.00 0.00 0
64 16 float sum -1 129.0 0.00 0.00 0 3220.1 0.00 0.00 0
128 32 float sum -1 129.9 0.00 0.00 0 131.3 0.00 0.00 0
256 64 float sum -1 129.6 0.00 0.00 0 128.4 0.00 0.00 0
512 128 float sum -1 128.3 0.00 0.01 0 130.0 0.00 0.01 0
1024 256 float sum -1 130.1 0.01 0.01 0 129.8 0.01 0.01 0
2048 512 float sum -1 2767.1 0.00 0.00 0 126.9 0.02 0.02 0
4096 1024 float sum -1 131.0 0.03 0.05 0 20.06 0.20 0.31 0
8192 2048 float sum -1 136.5 0.06 0.09 0 133.2 0.06 0.09 0
16384 4096 float sum -1 355.7 0.05 0.07 0 134.8 0.12 0.18 0
32768 8192 float sum -1 140.5 0.23 0.35 0 143.8 0.23 0.34 0
65536 16384 float sum -1 164.0 0.40 0.60 0 159.3 0.41 0.62 0
131072 32768 float sum -1 3877.2 0.03 0.05 0 4068.1 0.03 0.05 0
262144 65536 float sum -1 357.9 0.73 1.10 0 345.7 0.76 1.14 0
524288 131072 float sum -1 596.6 0.88 1.32 0 580.4 0.90 1.35 0
1048576 262144 float sum -1 784.5 1.34 2.00 0 837.0 1.25 1.88 0
2097152 524288 float sum -1 5550.3 0.38 0.57 0 6989.3 0.30 0.45 0
4194304 1048576 float sum -1 4156.0 1.01 1.51 0 8842.0 0.47 0.71 0
8388608 2097152 float sum -1 8536.3 0.98 1.47 0 8681.9 0.97 1.45 0
16777216 4194304 float sum -1 16639 1.01 1.51 0 15646 1.07 1.61 0
33554432 8388608 float sum -1 28481 1.18 1.77 0 29903 1.12 1.68 0
67108864 16777216 float sum -1 57158 1.17 1.76 0 59731 1.12 1.69 0
134217728 33554432 float sum -1 123551 1.09 1.63 0 117790 1.14 1.71 0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0.623606
- Running GPU benchmarks such as gpu-burn on all 4 GPUs, with up to 11 GB of VRAM in use and all cards at 100% power draw, works perfectly fine.
- Disabling the IOMMU, following the "PCI Access Control Services (ACS)" instructions, didn't help either.
- A colleague of mine, using the same code in a fresh PyTorch environment and the same GPUs (4x RTX 3080 Ti) but a different motherboard, does not encounter any of these problems; the code runs without issues.
- Both CPU and RAM usage stay within normal limits throughout, so they’re not the culprits.
- Changing the master port didn’t affect the problem.
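In case it helps narrow things down, this is the stripped-down, allreduce-only repro I plan to try next, so that the model and the dataloader are out of the picture. The NCCL_P2P_DISABLE toggle (commented out) is only an assumption on my part; it is a standard NCCL environment variable that forces traffic through host memory, and I have not yet verified whether it avoids the crash:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    dist.init_process_group("nccl", init_method="env://", world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)
    # Roughly 1 GB per GPU, to push past the ~1.2 GB point where the crash appears.
    x = torch.ones(256 * 1024 * 1024, device=f"cuda:{rank}")
    for step in range(100):
        dist.all_reduce(x)
        x /= world_size  # keep values bounded across iterations
        torch.cuda.synchronize()
        if rank == 0:
            print("step", step, "ok")
    dist.destroy_process_group()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # os.environ["NCCL_P2P_DISABLE"] = "1"  # untested assumption: route traffic through host memory instead of P2P
    world_size = 4
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)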