Error on Node 0: ETIMEDOUT: connection timed out

I am running RPC with 3 nodes. In my code, the master node is able to call worker1’s and worker2’s forward functions and get the results back. After that, the loss backprop step is executed on the master node, which takes quite some time, and because of that I am getting the error below on the master node:

dist_autograd.backward(context_id, [losses])
RuntimeError: Error on Node 0: ETIMEDOUT: connection timed out

On the worker nodes I am getting the following output:

Failed to respond to 'Shutdown Proceed' in time, got error RPCErr:1:RPC ran for more than set timeout (5000 ms) and will now be marked with an error.
[W tensorpipe_agent.cpp:687] RPC agent for worker2 encountered error when sending outgoing request #92 to master: ETIMEDOUT: connection timed out
(the above line repeats many times)
Process Process-1:
[W tensorpipe_agent.cpp:545] RPC agent for worker2 encountered error when reading incoming request from master: ECONNRESET: connection reset by peer (this is expected to happen during shutdown)

EDIT:

StackTrace:

Process Process-1:
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/pytorch2/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/user/anaconda3/envs/pytorch2/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "main.py", line 70, in workers_init_rpc
    rpc.shutdown()
  File "/home/user/anaconda3/envs/pytorch2/lib/python3.7/site-packages/torch/distributed/rpc/api.py", line 78, in wrapper
    return func(*args, **kwargs)
  File "/home/user/anaconda3/envs/pytorch2/lib/python3.7/site-packages/torch/distributed/rpc/api.py", line 284, in shutdown
    _get_current_rpc_agent().join()
RuntimeError: [/opt/conda/conda-bld/pytorch_1607370156314/work/third_party/gloo/gloo/transport/tcp/pair.cc:575] Connection closed by peer [192.168.13.205]:28380

Init RPC code:

def workers_init_rpc(rank, world_size, options):
    # options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128,
    #                                         rpc_timeout=0,
    #                                         init_method="tcp://192.168.13.46:2222" )
    print(f'Rank {rank}: Proceed to init rpc')
    rpc.init_rpc(
        f"worker{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=options
    )
    print(f'Rank: {rank}, rpc init done')

    if rank == 0:
        print('Proceed to run_master')
        run_master()

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = 3
    processes = []
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=128, rpc_timeout = 10*60)
    rank = int(sys.argv[1])
    p = mp.Process(target=workers_init_rpc, args=(rank, world_size, options))
    p.start()
    processes.append(p)
    # mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
    for p in processes:
        p.join()

I have tried increasing the rpc_timeout parameter in TensorPipeRpcBackendOptions, but it’s not working. How can I keep the connection open for longer durations?

By default, after how long did the backward time out? And when you increased rpc_timeout to a very large value, how long did the backward run before timing out?

I set rpc_timeout = 3600 (1 hour) and it ran for around 2 minutes 11 seconds (after rpc_init) and then timed out.

I also got the following (after rpc_init) on the workers, which I forgot to mention in my question:

Failed to respond to 'Shutdown Proceed' in time, got error RPCErr:1:RPC ran for more than set timeout (5000 ms) and will now be marked with an error.

This gets printed before the worker output I showed in my question.

To be specific, as you asked: dist_autograd timed out in around 58 seconds. I put a timestamp before dist_autograd.backward(context_id, [losses]) and calculated the elapsed time when it threw the ETIMEDOUT error.
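Roughly, the timing check looked like this (an illustrative sketch, not my exact code; model and x come from my training loop):

import time

with dist_autograd.context() as context_id:
    losses = model(x).sum()                      # forward via rpc_sync succeeds
    start = time.time()
    try:
        dist_autograd.backward(context_id, [losses])
    except RuntimeError as err:                  # ETIMEDOUT surfaces here
        print(f'backward failed after {time.time() - start:.1f} s: {err}')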

@Yanli_Zhao I have made a small edit (the stack trace from the workers) in the question, please take a look.

Thanks for reporting this @matrix! If possible, could you paste a small repro (i.e. one that shows run_master) that results in this error? It would be great to post this to Issues · pytorch/pytorch · GitHub so we can determine whether this is an actual bug.

Code to reproduce the error:

import sys
import torch.distributed.rpc as rpc
import torch
import time
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed.rpc import RRef
from torch.distributed.optim import DistributedOptimizer
from torch import optim

def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    return rpc.rpc_sync(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )

class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        self.layer = nn.Linear(10, 20)

    def parameter_rrefs(self):
        return [RRef(p) for p in self.parameters() if p.requires_grad]

    def forward(self, x):
        return self.layer(x)

class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.layer = nn.Linear(20, 1)

    def parameter_rrefs(self):
        return [RRef(p) for p in self.parameters() if p.requires_grad]

    def forward(self, x):
        return self.layer(x)

class Net(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Net, self).__init__()
        self.encoder_rref = rpc.remote(
            "worker1",
            Net1,
            args = args,
            kwargs = kwargs
        )

        self.decoder_rref = rpc.remote(
            "worker2",
            Net2,
            args = args,
            kwargs = kwargs
        )

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(self.encoder_rref.remote().parameter_rrefs().to_here())
        remote_params.extend(self.decoder_rref.remote().parameter_rrefs().to_here())
        return remote_params

    def forward(self, x):
        x = _remote_method(Net1.forward, self.encoder_rref, x)
        x = _remote_method(Net2.forward, self.decoder_rref, x)
        return x

def run_master():
    model = Net()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )
    for i in range(10):
        with dist_autograd.context() as context_id:
            x = torch.randn(32, 10)
            loss = model(x)
            loss = loss.sum()
            print('Before dist_autograd')
            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)

def workers_init_rpc(rank, world_size, options):
    print(f'Rank {rank}: Proceed to init rpc')
    rpc.init_rpc(
        f"worker{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=options
    )
    print(f'Rank: {rank}, rpc init done')

    if rank == 0:
        print('Proceed to run_master')
        run_master()
    rpc.shutdown()


if __name__=="__main__":
    world_size = 3
    processes = []

    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout = 60*60)
    rank = int(sys.argv[1])

    p = mp.Process(target=workers_init_rpc, args=(rank, world_size, options))

    p.start()

    processes.append(p)

    for p in processes:
        p.join()

Save this in a .py file.

How to run:

Env. variable exports for all of these 3 nodes:

export MASTER_ADDR=<Node 0 IP>
export MASTER_PORT=<Node 0 port>
export GLOO_SOCKET_IFNAME=<network interface>
export TP_SOCKET_IFNAME=<network interface>

On Node 0: python <filename> 0
On Node 1: python <filename> 1
On Node 2: python <filename> 2

PyTorch Version: 1.7.1

Hey @matrix, I made some minor edits to the source code and it works for me locally. See the code below. The only problem I noticed with the original code was that TensorPipeRpcBackendOptions is not picklable, so you cannot pass it as a multiprocessing arg. I moved it into workers_init_rpc.

import sys
import torch.distributed.rpc as rpc
import torch
import time
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed.rpc import RRef
from torch.distributed.optim import DistributedOptimizer
from torch import optim

import os

def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    return rpc.rpc_sync(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )

class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        self.layer = nn.Linear(10, 20)

    def parameter_rrefs(self):
        return [RRef(p) for p in self.parameters() if p.requires_grad]

    def forward(self, x):
        return self.layer(x)

class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.layer = nn.Linear(20, 1)

    def parameter_rrefs(self):
        return [RRef(p) for p in self.parameters() if p.requires_grad]

    def forward(self, x):
        return self.layer(x)

class Net(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Net, self).__init__()
        self.encoder_rref = rpc.remote(
            "worker1",
            Net1,
            args = args,
            kwargs = kwargs
        )

        self.decoder_rref = rpc.remote(
            "worker2",
            Net2,
            args = args,
            kwargs = kwargs
        )

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(self.encoder_rref.remote().parameter_rrefs().to_here())
        remote_params.extend(self.decoder_rref.remote().parameter_rrefs().to_here())
        return remote_params

    def forward(self, x):
        x = _remote_method(Net1.forward, self.encoder_rref, x)
        x = _remote_method(Net2.forward, self.decoder_rref, x)
        return x

def run_master():
    model = Net()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )
    for i in range(10):
        with dist_autograd.context() as context_id:
            x = torch.randn(32, 10)
            loss = model(x)
            loss = loss.sum()
            print('Before dist_autograd')
            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)

    print("finished training")

def workers_init_rpc(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    print(f'Rank {rank}: Proceed to init rpc')
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout = 60*60)
    rpc.init_rpc(
        f"worker{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=options
    )
    print(f'Rank: {rank}, rpc init done')

    if rank == 0:
        print('Proceed to run_master')
        run_master()
    rpc.shutdown()


if __name__=="__main__":
    world_size = 3
    processes = []

    #options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout = 60*60)
    #rank = int(sys.argv[1])

    #p = mp.Process(target=workers_init_rpc, args=(rank, world_size, options))

    mp.spawn(workers_init_rpc, args=(world_size, ), nprocs=3, join=True)

    #p.start()

    #processes.append(p)

    #for p in processes:
    #    p.join()

How to run

python <filename>

Output

Rank 2: Proceed to init rpc
Rank 0: Proceed to init rpc
Rank 1: Proceed to init rpc
Rank: 0, rpc init done
Proceed to run_master
Rank: 2, rpc init done
Rank: 1, rpc init done
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
Before dist_autograd
finished training

@mrshenli I’m wondering if we can somehow report better errors when passing unpicklable objects in torch.multiprocessing? Ideally it seems like this error would’ve been caught earlier instead of manifesting in this confusing way.

@rvarm1 the printed error in my local test is indeed "cannot pickle 'TensorPipeRpcBackendOptions' object". In the original code, pickling TensorPipeRpcBackendOptions happens before initializing RPC or gloo. So I suspect @matrix was hitting a different error, but I cannot reproduce that error locally.

Traceback (most recent call last):
  File "tmp1.py", line 129, in <module>
    mp.spawn(workers_init_rpc, args=(world_size, options), nprocs=3, join=True)
  File "/raid/shenli/pytorch/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/raid/shenli/pytorch/torch/multiprocessing/spawn.py", line 179, in start_processes
    process.start()
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle 'TensorPipeRpcBackendOptions' object
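For reference, the pickling failure can be reproduced in isolation, before any RPC or gloo initialization (a minimal sketch):

import pickle
import torch.distributed.rpc as rpc

# No init_rpc / init_process_group has run yet; pickling the options object alone fails.
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=60 * 60)
pickle.dumps(options)  # TypeError: cannot pickle 'TensorPipeRpcBackendOptions' object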

@mrshenli The above code works with mp.spawn (locally) but not with mp.Process (the distributed way), even after moving the options init into the workers_init_rpc function. The reason I am not using mp.spawn is that, as I said in my previous post, I have 3 nodes with a variable number of GPUs (1, 2 & 4). My model is divided into 2 parts. The first part resides on the node that has 4 GPUs, and the modules of this part are split across these 4 GPUs (so that each GPU holds a nearly equal number of parameters). The second part resides on the machine that has 2 GPUs, and its modules are split across the 2 GPUs in the same way. Unfortunately, I can’t fit the whole model on a single GPU; all of these GPUs have 8 GB of VRAM, which is not enough in my case. That is why I am using the mp.Process approach for training.

Furthermore, the node with 1 GPU acts as the master; it combines these 2 parts of the model into one and calls them (using rpc_sync) sequentially to run a full forward pass. Also, this splitting work (copying modules to GPUs) is done in the __init__ methods (think of Net1 and Net2 as the 2 parts of the model); could this be a problem when using mp.spawn?
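To illustrate what I mean by dividing the modules onto GPUs inside __init__, it looks roughly like this (a sketch with made-up submodules, not my actual model; only 2 devices shown to keep it short):

import torch
import torch.nn as nn

class Net1(nn.Module):
    # First part of the model, hosted on the worker that has 4 GPUs.
    def __init__(self):
        super(Net1, self).__init__()
        # Submodules are placed on different local GPUs so that each GPU
        # holds a roughly equal number of parameters.
        self.block_a = nn.Linear(10, 20).to('cuda:0')
        self.block_b = nn.Linear(20, 20).to('cuda:1')

    def forward(self, x):
        x = self.block_a(x.to('cuda:0'))
        x = self.block_b(x.to('cuda:1'))
        return x.cpu()  # return a CPU tensor over RPC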

When using mp.Process, the master node is able to fetch the results from the workers by making rpc_sync calls. The problem comes when executing dist_autograd.backward(context_id, [losses]) on the master node: it hangs on this line and, because of that, the ETIMEDOUT error is generated (after moving the options line into the workers_init_rpc method). This is true both for my real case (with GPUs) and for the code I posted here to reproduce the error.

Note:
PyTorch Version: 1.7.1 (on all 3 nodes)

I forgot to mention an observation:

Every time I execute this (on the workers), a different port is used (it’s not constant). For example, port 28380 was used on a particular worker.

Is this okay?

@mrshenli Try running the code the way I described in the "How to run" section above, on 3 separate nodes.

You might not be able to reproduce this error locally using mp.spawn.

Thank you all for helping me out.

@mrshenli @rvarm1 @Yanli_Zhao Any updates on this?

Sorry about the delay. If rpc_sync succeeded in your environment, then it means at least the comm layer is working. So if the backward hangs, I would assume something is wrong with the backward instead of with mp.spawn. But let me try Process instead of spawn anyway.

I tried the following two implementations. The first one uses mp.Process to spawn processes; with the second one, I ran python test.py 0/1/2 in three different terminal tabs. Both work for me.

# First implementation: spawn all three ranks with mp.Process.
if __name__=="__main__":
    world_size = 3
    processes = []

    for rank in range(3):
        p = mp.Process(target=workers_init_rpc, args=(rank, world_size))
        p.start()
        processes.append(p)

    [p.join() for p in processes]

# Second implementation: run "python test.py <rank>" in three separate terminal tabs.
if __name__=="__main__":
    world_size = 3

    rank = int(sys.argv[1])
    workers_init_rpc(rank, world_size)

Is the code you shared above exactly the same code where you hit the hang problem? There is a known gap in distributed autograd: we currently only support fast-mode distributed autograd, which means all RPC comm operations (rpc_sync, rpc_async, remote) must participate in the backward, otherwise the backward will hang. See more details in the doc below.

https://pytorch.org/docs/stable/rpc/distributed_autograd.html
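To make that constraint concrete, here is a sketch of the kind of pattern that trips fast mode (illustrative only, reusing the Net/Net1 helpers from the repro above, not a claim about your actual code). An RPC made inside the autograd context whose result never feeds into the loss leaves a send function on the remote worker that never receives a gradient, so the backward waits forever:

with dist_autograd.context() as context_id:
    x = torch.randn(32, 10)
    loss = model(x).sum()  # these RPCs participate in the backward

    # Extra RPC inside the same context whose output is NOT connected to loss
    # (e.g. an additional forward used only for logging). Its send function on
    # worker1 never receives a gradient, so the fast-mode backward hangs.
    _unused = _remote_method(Net1.forward, model.encoder_rref, torch.randn(32, 10))

    dist_autograd.backward(context_id, [loss])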

@mrshenli I made some changes as you suggested and now my code works. However, there’s a problem.

On the master node I get the following output (I have put prints in my code):

Batch forward complete
Epoch: 0, Batch Id: 0,Train Loss: 16.620447158813477
Before dist_autograd (Before executing dist_autograd.backward())
Step done (After optimizer.step() is executed, of course its an instance of DistributedOptimizer)
Batch forward complete (Completion of the forward pass of a batch)
Epoch: 0, Batch Id: 1,Train Loss: 12.148786544799805
Before dist_autograd

It gets stuck here. There are no error messages on any of the nodes. Here rpc_timeout is set to 1 hour.

Then I changed rpc_timeout to 1 minute, and got the following on the master node:

Batch forward complete
Epoch: 0, Batch Id: 0,Train Loss: 58.77790832519531
Before dist_autograd
Step done
Batch forward complete
Epoch: 0, Batch Id: 1,Train Loss: 40.98801803588867
Before dist_autograd
Process Process-1:
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/pytorch2/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/user/anaconda3/envs/pytorch2/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "main.py", line 65, in workers_init_rpc
    run_master()
  File "main.py", line 49, in run_master
    train_one_epoch(model, opt, data_loader, epoch, print_freq=10)
  File "/home/user/Documents/Incremental_Learning/demo/helpers/engine.py", line 45, in train_one_epoch
    dist_autograd.backward(context_id, [losses])
RuntimeError: Error on Node 0: RPCErr:1:RPC ran for more than set timeout (60000 ms) and will now be marked with an error

and the following on the workers:

[W tensorpipe_agent.cpp:545] RPC agent for worker1 encountered error when reading incoming request from worker0: ECONNRESET: connection reset by peer (this is expected to happen during shutdown)

It does time out after 1 minute. There’s no problem when processing the first batch; it’s smooth and doesn’t take long. But when processing the second batch, dist_autograd hangs, as you can see in the output above. This is strange behaviour.

@mrshenli any idea why it is behaving so strangely?

Hey @matrix, sorry about the delay again. We really need to work on our on-call procedure to cover pending discussions, not just new ones.

For the timeout error, the first thing I would check is whether the network is indeed working. One way to do that is to call rpc_sync (not remote or rpc_async) between all pairs of nodes. If that works, it means the network indeed works. From the timeout message, I cannot tell why distributed autograd does not work. If you could share your latest code with me, I can grab three machines on AWS and try it.
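Something like the following ping test is what I have in mind (a rough sketch; worker names and world size assumed to match the repro above). Run it on every node right after init_rpc, before any training, and it should tell you whether the comm layer between every pair of nodes is healthy:

import torch
import torch.distributed.rpc as rpc

def ping(x):
    # Runs on the callee; getting the result back verifies round-trip comm.
    return x + 1

def check_all_pairs(rank, world_size):
    # rpc_sync blocks until the remote call returns, so a broken link
    # shows up immediately as a timeout or connection error.
    for dst in range(world_size):
        if dst == rank:
            continue
        out = rpc.rpc_sync(f"worker{dst}", ping, args=(torch.ones(1),))
        print(f"worker{rank} -> worker{dst}: ok ({out.item()})")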