Pytorch RPC maximum number of concurrent RPCs?

(Running on the latest pytorch nightly)

I am attempting to implement a distributed RL training setup with batched inference (similar to Implementing Batch RPC Processing Using Asynchronous Executions — PyTorch Tutorials 1.8.0 documentation). I have a working setup with a small number of RPCs per process (12 processes, with 15 “play_game” RPCs active at once per process).

However, when I attempt to increase the number of games played simultaneously by the worker processes (from 15 to 16 RPCs per process), the setup freezes and eventually outputs the error

[E thread_pool.cpp:112] Exception in thread pool task: The RPC has not succeeded after the specified number of max retries (5).

hundreds of times after several minutes.

The strange thing is that 15 RPCs per process consistently succeeds, while 16 RPCs per process consistently fails. Is this a limit on the number of RPCs that can be in flight?

The test I am running is available at stone_ground_hearth_battles/test_pytorch_distributed.py at master · JDBumgardner/stone_ground_hearth_battles · GitHub

Hey @jeremysalwen, does increasing the value of num_worker_threads help? Distributed RPC Framework — PyTorch master documentation

Its default value is 16.
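For example, something along these lines when initializing RPC (a minimal sketch; the worker name, rank, world size, and thread count below are placeholders):

import torch.distributed.rpc as rpc

# Hypothetical values: raise the worker thread pool above the default of 16.
options = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=64,
)

rpc.init_rpc(
    "worker0",        # placeholder worker name
    rank=0,
    world_size=12,
    rpc_backend_options=options,
)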

@mrshenli Yes, that seems to fix the issue. However, I am still puzzled by the original behavior. Running out of worker threads should just block execution of the RPC until a thread is freed, but instead it seems to break other things.

This sounds like a bug, no?

To simplify my setup, I basically have two processes, process A and process B. Process A is sending 16 RPCs to process B, and all these RPCs are completely independent of each other. Each RPC from A->B internally sends a sequence of RPCs back to A. I don’t see why this should cause a deadlock, even if the thread pool in each process is size 1, because only 1 simultaneous RPC per process is required to make forward progress.
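To make the structure concrete, here is a minimal sketch of that pattern (the function and worker names below are made up; the real test is linked above):

import torch.distributed.rpc as rpc

# Runs on process B; invoked by a top-level RPC from process A.
def play_game(caller_name):
    actions = []
    for step in range(10):
        # Each step makes a nested, blocking RPC back to process A.
        # While this call waits, it holds one of B's RPC worker threads.
        actions.append(rpc.rpc_sync(caller_name, choose_action, args=(step,)))
    return actions

# Runs on process A; invoked by the nested RPC from process B.
def choose_action(step):
    return step % 2

# On process A, 16 independent top-level RPCs are launched at once:
# futs = [rpc.rpc_async("worker_b", play_game, args=("worker_a",)) for _ in range(16)]
# results = [f.wait() for f in futs]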

My guess is that the thread pool is being reused to both send and receive RPCs, so all 16 threads are taken by the RPCs from A to B, but now all these threads are stuck, because they are unable to make RPCs from B to A?

This is true for the ProcessGroup backend, but I am not 100% sure whether this (send/recv sharing the same thread pool) is also the case for the TensorPipe backend. @lcw could you please confirm?

Default RPC functions will block one thread on the callee side until the response is ready. If you would like the response to be processed asynchronously, you can decorate the user function with @rpc.functions.async_execution and let the user function return a Future. This releases the thread on the callee as soon as it gets the Future object, and the callback installed on the Future will run when the Future is completed.

From what I can remember, the TensorPipe agent does indeed have only one thread pool, which is shared for many “purposes”, so what you said is totally reasonable.

I don’t immediately recognize the initial error message you got, as it mentions retries, but I don’t think we support retries in “standard” RPC messages. I think we only support them for RRef-related messages and other internal stuff. @mrshenli Is that the case? @jeremysalwen Are you using RRefs, dist autograd, or other such things? Unfortunately the message doesn’t say what the underlying failure is, but I suspect it could be a timeout?

Also note that as @mrshenli said, it’s an antipattern to synchronously block in a remote RPC function or a callback. Doing so will block a thread and eventually lead to starvation. If you must do so, please ensure you’re doing it for a limited number of RPC calls, and size the thread pool accordingly. However, it would be better to use “native” asynchronous patterns.

Yep, this is true. We don’t yet have a good way to retry messages with user functions.

@mrshenli I am using @rpc.functions.async_execution for the calls from B to A, but I don’t see an easy way to do the same for the RPC from A to B. Remember, internally the RPC from A to B then calls back to A multiple times, so I would need to suspend/resume the thread in order to free it (i.e. I would need something like asyncio, which would complicate things much further).

@lcw I am using RRefs, but not dist autograd, or any other advanced features. I’m not sure how I would transform my code to be non-blocking. Fundamentally, my RPC running on process B needs to wait until its RPC to process A completes before continuing execution. I would expect that pytorch distributed should recognize this situation, and return the thread to the thread pool while it is waiting on the RPC. What could I change in my code so that it “returns the thread” to the thread pool, but continues executing once the RPC completes?

I would expect that pytorch distributed should recognize this situation, and return the thread to the thread pool while it is waiting on the RPC.

Unfortunately that’s not something that can easily be achieved. It could be done through Python’s asyncio package, but that’s not supported by PyTorch (for historical reasons, I guess, and it’s hard to retrofit now). Any other alternative would basically amount to reimplementing our own asyncio, and that’s just unreasonable.

What could I change in my code so that it “returns the thread” to the thread pool, but continues executing once the RPC completes?

I think your situation could be addressed by using the async_execution decorator and futures, as @mrshenli already mentioned. More concretely, here is what such code could look like:

import torch.distributed.rpc as rpc

# Imagine this is the original code (foo, bar, and my_other_function are placeholders).
# The nested rpc_sync blocks a callee-side thread until the other worker responds:
def my_function(a):
    b = foo(a)
    c = rpc.rpc_sync("other_worker", my_other_function, args=(b,))
    d = bar(c)
    return d

# It could become like this. The function returns a Future immediately, so the
# callee thread is released; bar() runs in a callback once the nested RPC completes:
@rpc.functions.async_execution
def my_function(a):
    b = foo(a)
    fut_c = rpc.rpc_async("other_worker", my_other_function, args=(b,))
    fut_d = fut_c.then(lambda fut: bar(fut.value()))
    return fut_d

For reference, I ultimately decided to directly integrate asyncio with pytorch RPC to get around this issue. I describe how I did it in this post: Pytorch Distributed RPC bottleneck in _recursive_compile_class - #9 by jeremysalwen
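The core idea is to bridge the torch RPC futures into asyncio futures so the code on process B can be written as ordinary coroutines. Roughly like this (a hedged sketch of the approach, not the exact code from the linked post; the helper name is made up):

import asyncio
import torch.distributed.rpc as rpc

def rpc_async_to_asyncio(to, func, args=()):
    # Wrap the torch.futures.Future returned by rpc_async in an asyncio
    # future so it can be awaited from a coroutine.
    loop = asyncio.get_running_loop()
    aio_fut = loop.create_future()

    def copy_result(fut):
        # This callback runs on an RPC agent thread, hence call_soon_threadsafe.
        try:
            result = fut.value()  # raises if the RPC completed with an error
        except Exception as exc:
            loop.call_soon_threadsafe(aio_fut.set_exception, exc)
        else:
            loop.call_soon_threadsafe(aio_fut.set_result, result)

    rpc.rpc_async(to, func, args=args).then(copy_result)
    return aio_fut

# Usage from inside a coroutine, e.g.:
# result = await rpc_async_to_asyncio("worker_a", some_function, args=(x,))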

This idea of using @rpc.functions.async_execution is interesting, but to me the example looks like it’s begging to be a coroutine instead of a pair of functions chained together with a callback :slight_smile: