DDP SocketTimeout error on Windows

Hi everyone,
I've developed this little POC using the PyTorch distributed package: essentially, a Trainer spawns N processes and orchestrates them using Python Pipes (it could also use Queues). Normally it would send data at every epoch, but in this POC the data is sent only once, at process creation. The processes train a model through DDP.

import os
import signal
import socket
from contextlib import closing
from multiprocessing.connection import Connection, Pipe
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP


def init_process(rank, world_size, ddp_free_port, recv, train_data):
    """Initialize the distributed environment for ``rank`` and run its Worker.

    Runs inside each spawned child process. Joins the gloo process group via a
    TCP rendezvous on localhost, trains until the Worker returns, and always
    tears the process group down and closes the pipe on the way out (also
    when training raises).

    Args:
        rank: this process's rank, in ``range(world_size)``.
        world_size: total number of worker processes.
        ddp_free_port: rendezvous port; ``str`` or ``int``.
        recv: this rank's end of the trainer pipe, handed to ``Worker``.
        train_data: dataset handed to ``Worker``.
    """
    # All ranks run on this host (localhost rendezvous), so limit intra-op
    # threads per process to avoid oversubscribing the CPU cores.
    torch.set_num_threads(1)
    # os.environ only accepts str values; accept an int port as well.
    port = str(ddp_free_port)
    # The explicit tcp:// init_method below carries the address and port
    # itself; these variables are kept for code that reads the env://
    # convention.
    # NOTE(review): on Windows "localhost" may resolve to ::1 before
    # 127.0.0.1; trying "127.0.0.1" may be worth it when chasing gloo
    # "Socket Timeout" errors — unverified.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["NODE_RANK"] = "0"
    dist.init_process_group(
        "gloo",
        init_method=f"tcp://{os.environ['MASTER_ADDR']}:{port}",
        rank=rank,
        world_size=world_size,
    )
    try:
        Worker(recv, train_data).train()
    finally:
        # Release the process group's resources and this rank's pipe end even
        # if training failed, instead of leaving teardown to process exit.
        dist.destroy_process_group()
        recv.close()


class Worker:
    """Run one rank's training loop, driven by messages received over a pipe.

    The trainer sends an epoch number to start an epoch and ``False`` to stop.
    After each epoch every rank meets at a barrier, then rank 0 sends ``True``
    back over its pipe to tell the trainer that the epoch is finished.
    """

    def __init__(self, queue, train_dset):
        """Build the DDP-wrapped model and optimizer for this rank.

        Must be called after ``dist.init_process_group``: the rank and world
        size are read from the default process group, and the model is wrapped
        in ``DistributedDataParallel``.

        Args:
            queue: this rank's end of the trainer pipe. Rank 0's end must also
                be able to send, because it acknowledges every epoch.
            train_dset: dataset yielding ``(images, labels)``; ``train``
                flattens images to 28 * 28 = 784 features.
        """
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.queue: Connection = queue
        self.train_dset = train_dset
        # 784 flattened input features -> 64 hidden units -> 10 output logits.
        self.model = torch.nn.Sequential(nn.Linear(784, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
        self.model = DDP(self.model)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def train(self):
        """Train one epoch per received epoch number until told to stop.

        Returns when the trainer sends ``False``, or when the trainer closes
        its end of the pipe without sending it.
        """
        loss_fn = nn.CrossEntropyLoss()
        sampler = torch.utils.data.distributed.DistributedSampler(
            self.train_dset, num_replicas=self.world_size, rank=self.rank, shuffle=True
        )
        train_loader = torch.utils.data.DataLoader(self.train_dset, sampler=sampler, batch_size=32)
        while True:
            try:
                epoch = self.queue.recv()
            except EOFError:
                # The trainer closed its end without sending the stop sentinel
                # (e.g. it crashed); stop cleanly instead of dying with a traceback.
                print(f"Rank-{self.rank}: trainer pipe closed, stopping")
                return
            # Identity check on purpose: epoch 0 == False, so `==` would stop early.
            if epoch is False:
                print(f"Rank-{self.rank} done!")
                return
            total_loss = 0
            # Give the sampler the epoch so each epoch shuffles differently.
            sampler.set_epoch(epoch)
            for images, labels in train_loader:
                out = self.model(images.view(-1, 28 * 28))
                loss = loss_fn(out, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            # Every rank must finish the epoch before rank 0 reports completion.
            dist.barrier()
            if self.rank == 0:
                # max(..., 1) guards against an empty shard (zero batches).
                mean_loss = total_loss / max(len(train_loader), 1)
                print(f"Epoch: {epoch}, Loss@rank-{self.rank}: {mean_loss:.4f}")
                print(f"Rank-0 is telling the trainer that everything is done for the epoch {epoch}")
                self.queue.send(True)


class Trainer:
    """Spawn one DDP worker process per rank and drive them epoch by epoch.

    Each rank gets its own pipe. Only rank 0's pipe is duplex, because rank 0
    is the one that acknowledges the end of every epoch.
    """

    def __init__(self, world_size: int, epochs: int = 5) -> None:
        """Load MNIST and reserve a rendezvous port.

        Args:
            world_size: number of worker processes to spawn.
            epochs: number of epochs to drive in ``train``.
        """
        self.world_size = world_size
        self.epochs = epochs
        self.train_data = torchvision.datasets.MNIST(
            "/tmp/data",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
        )
        self.test_data = torchvision.datasets.MNIST(
            "/tmp/data",
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
        )
        # Trainer-side pipe ends, indexed by rank; filled in by run().
        self.queues: List[Connection] = []
        # The port is only known to be free right now; another program could
        # still take it before rank 0's rendezvous binds it.
        self.ddp_free_port = str(find_free_port())

    def run(self):
        """Run the distributed environment: spawn all workers, then train."""
        print("Start training")
        queues = []
        processes = []
        for rank in range(self.world_size):
            # With duplex=False the first end is receive-only and the second
            # send-only; rank 0 needs a duplex pipe to reply after each epoch.
            child_end, parent_end = Pipe(duplex=(rank == 0))
            p = mp.Process(
                target=init_process,
                args=(rank, self.world_size, self.ddp_free_port, child_end, self.train_data),
                daemon=True,
            )
            p.start()
            # The child has its own handle now. Drop the parent's copy:
            # otherwise it keeps the pipe open and the trainer's recv() cannot
            # see EOF when a worker dies.
            child_end.close()
            queues.append(parent_end)
            processes.append(p)
        self.queues = queues
        self.train(queues, processes)

    def train(self, queues, processes, join_timeout: float = 10.0):
        """Drive ``self.epochs`` epochs, then shut every worker down.

        Args:
            queues: trainer-side pipe ends indexed by rank; ``queues[0]`` must
                be able to receive rank 0's end-of-epoch acknowledgement.
            processes: per-rank workers, either ``mp.Process`` objects (joined,
                and terminated only if still alive after ``join_timeout``) or
                plain PIDs (sent SIGTERM, as before).
            join_timeout: seconds to wait for each ``mp.Process`` to exit.
        """
        for epoch in range(self.epochs):
            for queue in queues:
                queue.send(epoch)
            print("Training waiting for rank-0")
            queues[0].recv()
        # Tell every rank to stop before terminating anyone, so that workers
        # get a chance to leave their loop cleanly.
        for queue in queues:
            queue.send(False)
            queue.close()
        for proc in processes:
            if isinstance(proc, int):
                # Legacy path: only a PID is available.
                os.kill(proc, signal.SIGTERM)
                continue
            proc.join(join_timeout)
            if proc.is_alive():
                proc.terminate()
                proc.join()


def find_free_port():
    """Return a TCP port number that is unused on this host at call time.

    Binds a throwaway socket to port 0 so the OS picks an unused port, then
    releases it. The port is only known to be free at the moment of the call:
    another process may claim it before the caller binds it.

    The original ``setsockopt(SO_REUSEADDR)`` ran after ``bind()``, so it had
    no effect on the bind; it has been dropped.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


if __name__ == "__main__":
    # Script entry point: configure the environment, then train with 16 ranks.
    # NOTE(review): LOGLEVEL is read by torch.distributed.elastic's logger;
    # c10d/gloo diagnostics are more likely driven by TORCH_CPP_LOG_LEVEL and
    # TORCH_DISTRIBUTED_DEBUG — confirm before relying on this for debugging.
    os.environ["LOGLEVEL"] = "DEBUG"
    # Spawned children inherit os.environ as it is at spawn time, so set it first.
    mp.set_start_method("spawn")
    Trainer(world_size=16).run()
    print("Finished training")

I randomly receive the following error for every spawned process if I increase the number of processes, for example from 16 to 32:

...
Process Process-1:
Traceback (most recent call last):
  File "C:\Program Files (x86)\Python38\lib\multiprocessing\process.py", line 315, in _bootstrap
    self.run()
  File "C:\Program Files (x86)\Python38\lib\multiprocessing\process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "c:\Users\belof\Desktop\temp\examples\ddp_cpu.py", line 27, in init_process
    dist.init_process_group("gloo", init_method=f"tcp://localhost:{ddp_free_port}", rank=rank, world_size=world_size)
  File "C:\Users\belof\Desktop\temp\.venv\lib\site-packages\torch\distributed\distributed_c10d.py", line 602, in init_process_group
    default_pg = _new_process_group_helper(
  File "C:\Users\belof\Desktop\temp\.venv\lib\site-packages\torch\distributed\distributed_c10d.py", line 703, in _new_process_group_helper
    pg = ProcessGroupGloo(prefix_store, rank, world_size, timeout=timeout)
RuntimeError: Socket Timeout
Traceback (most recent call last):
  File "C:\Program Files (x86)\Python38\lib\multiprocessing\connection.py", line 312, in _recv_bytes
    nread, err = ov.GetOverlappedResult(True)
BrokenPipeError: [WinError 109] The pipe has been ended

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:/Users/belof/Desktop/temp/examples/ddp_cpu.py", line 131, in <module>
    trainer.run()
  File "c:/Users/belof/Desktop/temp/examples/ddp_cpu.py", line 106, in run
    self.train(queues, processes)
  File "c:/Users/belof/Desktop/temp/examples/ddp_cpu.py", line 113, in train
    queues[0].recv()
  File "C:\Program Files (x86)\Python38\lib\multiprocessing\connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "C:\Program Files (x86)\Python38\lib\multiprocessing\connection.py", line 321, in _recv_bytes
    raise EOFError
EOFError

It seems to me that this is related to the Windows spawn start method and to the queue references passed to the processes, but I don't really know what is happening here.
This is the result of the collect_env.py script:

Collecting environment information...
PyTorch version: 1.12.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.8 (tags/v3.8.8:024d805, Feb 19 2021, 13:18:16) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.931
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] pytorch-lightning==1.6.4
[pip3] torch==1.12.1
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.12.0
[conda] Could not collect

It seems this issue was first raised on GitHub; let's continue the discussion there.

1 Like

For anyone to follow: Gloo DDP SocketTimeout error on Windows · Issue #85621 · pytorch/pytorch · GitHub

Has this been resolved?

1 Like