Can RPC leverage multicore?

I am using torch.distributed.rpc. I can configure the RPC agent to use many threads via rpc_backend_options, but those threads do not seem to be mapped onto the idle CPUs that I have.
Specifically, to test this out, I've sent 1-4 asynchronous RPC calls to a server that has 80 CPUs.
Below is the code for reference.

import os
import time
import torch
import torch.nn as nn
import numpy as np
from torch.multiprocessing import Process
import torch.distributed.rpc as rpc

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l0 = nn.Linear(2, 2)
        W = np.random.normal(0, 1, size=(2,2)).astype(np.float32)
        self.l0.weight.data = torch.tensor(W, requires_grad=True)
        self.l1 = nn.Linear(2, 2)
        W = np.random.normal(0, 1, size=(2,2)).astype(np.float32)
        self.l1.weight.data = torch.tensor(W, requires_grad=True)

    def forward(self, x):
        return self.l1(self.l0(x))

def test(t):
    # Compute-bound busy loop so the callee stays on the CPU for a while.
    print("RPC called")
    for i in range(100000):
        t2 = t * 1.000001
    return t2

def run(i):
    rpc.init_rpc("Rank"+str(i), rank=i, world_size=2)
    if i == 0:
        with torch.autograd.profiler.profile(True, False) as prof:
            net = Net()
            optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
            input = torch.tensor([1.0, 2.0])
            reqs = []
            reqs.append(rpc.rpc_async("Rank1", test, args=(input,)))
            reqs.append(rpc.rpc_async("Rank1", test, args=(input,)))
            reqs.append(rpc.rpc_async("Rank1", test, args=(input,)))
            reqs.append(rpc.rpc_async("Rank1", test, args=(input,)))
            #reqs.append(rpc.rpc_async("Rank1", test, args=(input,)))
            for req in reqs:
                input += req.wait()
            print("RPC Done")
            y = net(input)
            optimizer.zero_grad()
            y.sum().backward()
            optimizer.step()
        print(prof.key_averages().table(sort_by="cpu_time_total"))
        prof.export_chrome_trace("test.json")
    else:
        # Rank1 does nothing itself; it only serves incoming RPCs until shutdown.
        pass

    rpc.shutdown()

if __name__ == "__main__":
    os.environ['MASTER_ADDR'] = "localhost"
    os.environ['MASTER_PORT'] = "29500"
    ps = []
    for i in [0, 1]:
        p = Process(target=run, args=(i,))
        p.start()
        ps.append(p)

    for p in ps:
        p.join()

As you can see, I am just doing some compute-intensive work on the server using RPC.
Below are the results from my profiler.

When I do 1 RPC call: (profiler table screenshot; roughly 460 ms total)

When I do 4 RPC calls: (profiler table screenshot; roughly 2200 ms total)
The default RPC initialization creates 4 send_recv_threads, so it should be able to run my 4 RPC requests "concurrently". However, as you can see, the time to finish the RPC requests grew almost linearly (from about 460 ms to 2200 ms) as I went from 1 to 4 requests, which suggests they are all running on a single core and are not being processed in parallel (i.e., concurrent, but not parallel).
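
For completeness, this is roughly how I set the thread count through rpc_backend_options (a sketch; the script above just uses the defaults, and ProcessGroupRpcBackendOptions / num_send_recv_threads are the knobs I am referring to):

import torch.distributed.rpc as rpc

# Drop-in replacement for the rpc.init_rpc(...) line inside run(i) above:
# give the ProcessGroupAgent more send/recv worker threads (default is 4).
options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=16)
rpc.init_rpc("Rank" + str(i), rank=i, world_size=2,
             rpc_backend_options=options)

Even with more threads configured, the calls did not seem to spread across the idle CPUs, which is what led to the question below.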

I know that Python threads (unlike processes) cannot execute in parallel on different cores. Are the RPC threads, since they are also threads, likewise unable to run in parallel on different cores?
Is there a way to run different RPC requests received by the server on different cores? Or should I manually spawn processes on the receiving server to run the requests in parallel and leverage my multicore machine?

Thank you.


This is likely due to the Python GIL on the server side. Can you TorchScript (https://pytorch.org/docs/stable/jit.html) the test method and try again? That should avoid the GIL.
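
For example, something along these lines (a sketch; note that TorchScript needs the loop result defined before the loop):

import torch

@torch.jit.script
def test(t: torch.Tensor) -> torch.Tensor:
    # Same busy loop, but compiled to TorchScript: the loop runs in the
    # TorchScript interpreter and does not need to hold the Python GIL.
    t2 = t
    for i in range(100000):
        t2 = t * 1.000001
    return t2

If your PyTorch version supports TorchScript functions over RPC (I believe 1.5+), the rpc_async call sites on the caller stay the same.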


Probably yes for TorchScript functions, and no for regular Python functions.

I will describe my understanding of the framework in detail. For the C++ files below, the root directory is torch/csrc/distributed (git master).

Internally, the torch RPC framework does the following:

  1. Entry point: rpc.rpc_async and rpc.rpc_sync use the same _invoke_rpc(), while rpc.remote has its own implementation. All three categorize your call into one of three kinds: builtin for builtin torch operators, udf for user-defined Python functions, and jit for TorchScript functions (see the Python-level sketch further below). All three then descend into the torch C++ library.

  2. C library: the C++/Python interface is defined in rpc/init.cpp, which uses pybind11 to define the bindings. The call guards (constructed before the wrapped functions run) are gil_scoped_release, so the GIL is released here for all three interface categories. The wrapped functions are defined in rpc/python_functions.cpp; they find your target RPC process (the C++ class RpcAgent) and send your message to it. The GIL is reacquired after the C++ call finishes, due to the automatic destruction of the call guards.

  3. RPC agent: the RpcAgent class initializes its cb_ member on construction, a unique_ptr of type RequestCallback. There are two derived classes of RpcAgent: TensorPipeAgent and ProcessGroupAgent. In this use case you are dealing with ProcessGroupAgent; in its member function handleRecv() it uses cb_ to handle the callback. All agents are defined in rpc/<agent_name_lower_case>.cpp, so you should be able to find them easily.

  4. Request callback: RequestCallback is an abstract functor defined in request_callback.h and request_callback.cpp with a virtual processMessage() method. The concrete implementation in request_callback_impl.cpp overrides this method and calls the member method processRpc().

  5. Process RPC: this function handles RPC calls based on the message type. SCRIPT_CALL and SCRIPT_REMOTE_CALL handle JIT-scripted calls; however, in PYTHON_CALL:

      {
        py::gil_scoped_acquire acquire;
        serializedPyObj =
            std::make_shared<SerializedPyObj>(pythonRpcHandler.serialize(
                pythonRpcHandler.runPythonUdf(std::move(upc).movePythonUdf())));
      }
    

There is a py::gil_scoped_acquire here; according to the pybind11 documentation, it holds the GIL until acquire is destroyed (i.e., when this C++ scope is left). So no, for regular Python UDFs you cannot leverage multiple cores through the threads used inside RPC.
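
To make the categories from step 1 concrete at the Python level, here is a hedged sketch (the function names are mine, and which path each call takes is my reading of the source): only the plain Python UDF has a Python function body to execute on the callee, so it is the one that needs the py::gil_scoped_acquire shown above.

import torch
import torch.distributed.rpc as rpc

def python_udf(t):
    # User-defined Python function: served via runPythonUdf, which holds the GIL.
    return t * 1.000001

@torch.jit.script
def scripted(t: torch.Tensor) -> torch.Tensor:
    # TorchScript function: the body runs in the TorchScript interpreter,
    # so no Python bytecode (and no GIL) is needed while it executes.
    return t * 1.000001

# With RPC initialized (as in the question's run()), all three kinds are
# invoked the same way from the caller; only the callee-side handling differs:
#   rpc.rpc_async("Rank1", torch.add, args=(x, x))    # builtin torch operator
#   rpc.rpc_async("Rank1", python_udf, args=(x,))     # Python UDF
#   rpc.rpc_async("Rank1", scripted, args=(x,))       # TorchScript function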

Note that this explanation is valid for commit hash 176174a68ba2d36b9a5aaef0943421682ecc66d4 and releases up to 1.5.0; from the source code it looks like they are planning to further abstract the switch statement in processRpc() into an execute() method on RpcCommandBase.


So your code is theoretically equivalent to using ThreadPool(4).map from the multiprocessing library. The test code is as follows:

import time

import torch as t
from multiprocessing.pool import ThreadPool

def test(t):
    print("RPC called")
    for i in range(100000):
        t2 = t * 1.000001
    return t2

def test1():
    # Run the workload 4 times sequentially.
    ts = t.Tensor([1, 2])
    begin = time.perf_counter()
    for i in range(4):
        test(ts)
    print(time.perf_counter() - begin)

def test2():
    # Run the same workload on 4 Python threads.
    ts = t.Tensor([1, 2])
    pool = ThreadPool(4)
    begin = time.perf_counter()
    pool.map(test, (ts, ts, ts, ts))
    print(time.perf_counter() - begin)

And my test results are:
test1: 2.091 s
test2: 2.404 s

You can dispatch subprocesses inside the RPC call to work around the GIL; a rough sketch is below.
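
For example (the helper names here are mine; since the RPC agent already runs several threads, the 'spawn' multiprocessing start method may be safer than the default 'fork' on Linux):

import torch
import torch.distributed.rpc as rpc
from concurrent.futures import ProcessPoolExecutor

# Hypothetical module-level pool on the callee side. ProcessPoolExecutor only
# spawns its worker processes on first use, so importing this module on the
# caller is harmless.
_pool = ProcessPoolExecutor(max_workers=4)

def heavy(t):
    # Same compute-bound loop as `test` above.
    for i in range(100000):
        t2 = t * 1.000001
    return t2

def heavy_in_subprocess(t):
    # RPC target: hand the work to a child process so it escapes the GIL of
    # the RPC worker process; the handler thread just blocks on the result.
    return _pool.submit(heavy, t).result()

# On the caller, with RPC initialized as in the original script:
#   futs = [rpc.rpc_async("Rank1", heavy_in_subprocess, args=(x,)) for _ in range(4)]
#   results = [f.wait() for f in futs]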
