Remove the dependency of Gloo from RPC

Hi there, my distributed RPC project has ~500 processes, and the Gloo initialization builds a TCP connection between every pair of processes (on the order of 500 × 499 / 2 ≈ 125,000 connections), which results in an “Address already in use” error.

In an older version of PyTorch, I found that the only place Gloo is used in RPC is the function “_tensorpipe_init_backend_handler” in the file “torch/distributed/rpc/backend_registry.py”. Since the only purpose of Gloo inside RPC is to make rpc.shutdown work properly, and my project runs an infinite loop without any safe termination logic, it is safe to remove Gloo entirely. I was able to “hack” “_tensorpipe_init_backend_handler” by commenting out all lines related to Gloo, and it worked very well.

But in the newest version of PyTorch, the TensorPipeAgent constructor requires the Gloo “group” as an argument, which makes this hack no longer possible. I am wondering whether there is any alternative hack to get rid of the Gloo group initialization. Thanks!

Thanks for the question, Levi_Ackerman! @H-Huang, what do you think? I see that there’s code for dynamic RPC in that function as well, which seems not to use Gloo, but I’m not sure if it’s ready for use.

Hi @Levi_Ackerman, ~500 processes seems like a “titan” problem :slight_smile:. Which version of PyTorch are you using? I believe ProcessGroup has been part of TensorPipeAgent initialization since 1.7.

In PT 1.7, TensorPipeAgent added ProcessGroup(Gloo) to its initialization, which requires Gloo to be initialized prior to TensorPipeAgent initialization:
[TensorPipe] Implement join correctly (#38933) · pytorch/pytorch@54046c1 · GitHub.

In PT 1.11, we removed the dependency on ProcessGroup from TensorPipeAgent initialization. This means the shutdown of TensorPipeAgent no longer depends on ProcessGroups; however, a ProcessGroup is still used before TensorPipeAgent initialization to exchange device_map / device information: Remove ProcessGroup from TensorPipeAgent initialization by H-Huang · Pull Request #66708 · pytorch/pytorch · GitHub
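
For context, the device_map / device information exchanged there is whatever users pass through TensorPipeRpcBackendOptions. A minimal sketch (the worker name and GPU mapping below are made-up placeholders):

import torch.distributed.rpc as rpc

# Hypothetical example: map this worker's cuda:0 to cuda:1 on "worker1" so CUDA
# tensors can be sent over RPC; this mapping is what gets gathered and checked
# across all ranks before the agent starts.
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16)
options.set_device_map("worker1", {0: 1})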

In PT 1.12, we introduced dynamic RPC as a prototype, which allows for an optional world_size argument, allows nodes to join and leave an RPC group, and completely removes the process group for dynamic groups in both startup and shutdown.

As a result, I think you have two options:

  1. Use PT 1.11+. TensorPipeAgent no longer includes a ProcessGroup in its constructor, but you’ll still need to hack backend_registry.py so it does not create a ProcessGroup when the RPC agent is being bootstrapped.
  2. Use PT 1.12 and the dynamic / elastic RPC feature (see the sketch below); however, as @aazzolini mentioned, it is still a prototype and needs more investment.
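
To make option 2 concrete, here is a minimal sketch of what the dynamic RPC startup could look like. Omitting world_size follows the prototype description above, but treat the exact call pattern as an assumption and check the PT 1.12 docs for your build:

import os
import torch.distributed.rpc as rpc

# Assumes the launcher provides RANK, MASTER_ADDR and MASTER_PORT as usual.
rank = int(os.environ["RANK"])

# No world_size: membership is dynamic, so ranks can join (and later leave).
rpc.init_rpc(f"worker{rank}", rank=rank)

# ... issue rpc.rpc_sync / rpc.rpc_async calls as usual ...

rpc.shutdown()  # this node leaves the group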

If you could tell me your PyTorch version and share a small code snippet of what you have, I can also provide further advice. Thanks!


Hi @aazzolini @H-Huang , thank you so much for your responses!

Yes, ~500 processes is a “titan” project, and I was supposed to write a custom communication protocol for it. But since RPC provides a super convenient interface and highly optimized communication, I chose not to re-invent the wheel. Despite the large number of processes, the number of communication pairs grows only linearly with the number of processes, since the majority of data transfer uses a one-to-many topology rather than the all-to-all mesh used by Gloo.

Yes, I knew the Gloo group is also used for exchanging device_map information. But since in most cases I will use the default device setting, I hacked the code and set “reverse_device_maps” to an empty dict and “devices” to an empty list.

I am using PT 1.10.1 since my lab cluster only has CUDA 11.1. I have installed newer versions of PyTorch on it, but they behaved weirdly. Here is my hacked code from PT 1.10.1, following the approach that was working for me before. I removed the original comments for simplicity.

def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options):
    from . import TensorPipeRpcBackendOptions
    from . import TensorPipeAgent

    if not isinstance(store, dist.Store):
        raise TypeError("`store` must be a c10d::Store. {}".format(store))

    if not isinstance(
        rpc_backend_options, TensorPipeRpcBackendOptions
    ):
        raise TypeError(
            "`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {}".format(
                rpc_backend_options
            )
        )

    # hacking, remove the dependency on the Gloo group
    # group = _init_process_group(store, rank, world_size)

    if torch.cuda.is_available():
        torch.cuda.init()
        device_count = torch.cuda.device_count()
    else:
        device_count = 0

    # hacking, we use the default device setting, so there is no need to
    # exchange device_map information
    # reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
    #     name,
    #     device_count,
    #     rpc_backend_options.device_maps,
    #     rpc_backend_options.devices,
    #     group,
    # )

    # hacking, in PT 1.10.1 the TensorPipeAgent constructor still requires a
    # ProcessGroup argument, so `group` is needed here even though its
    # initialization is commented out above
    agent = TensorPipeAgent(
        store,
        name,
        rank,
        world_size,
        group,
        rpc_backend_options,
        # hacking, pass an empty dict / list in place of the original
        # reverse_device_maps and devices arguments
        {},
        [],
    )

    api._init_rpc_states(agent)

    # hacking, we don't need the following two lines since we will never call rpc.shutdown();
    # replace each occurrence of rpc.shutdown() with something like time.sleep(1 << 31)
    # api._all_gather(None, timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC)
    # group.barrier().wait()

    return agent

Also, it would be helpful to be able to assign different numbers of threads to different RPC processes. The master process, which serves as the central node for one-to-all communication, needs more threads than the worker processes. The current setting, which applies the same fixed maximum number of threads to every process, can easily result in an overwhelming total number of threads on a single node.

Thanks for the reply @Levi_Ackerman!

Your code looks fine. The only way to continue using PT 1.10.1 along with your hack would be to create a stub process group to use as a placeholder, like below, since the ProcessGroup is not used in RPC besides shutdown:

import torch
import torch.distributed

# Stub that subclasses the c10d ProcessGroup with a dummy rank and size,
# just to satisfy the TensorPipeAgent constructor signature in PT 1.10.1.
class ProcessGroupStub(torch.distributed.ProcessGroup):
    def __init__(self):
        super().__init__(-1, -1)

group = ProcessGroupStub()
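
To be explicit about where this goes (a sketch against your hacked handler, not something I have tested): the stub would simply replace the commented-out _init_process_group call, so the PT 1.10.1 constructor still receives a ProcessGroup-typed argument:

# In the hacked _tensorpipe_init_backend_handler, instead of
#     group = _init_process_group(store, rank, world_size)
# construct the placeholder:
group = ProcessGroupStub()

agent = TensorPipeAgent(
    store,
    name,
    rank,
    world_size,
    group,                # stub instead of a real Gloo group
    rpc_backend_options,
    {},                   # reverse_device_maps (default: no device maps)
    [],                   # devices
)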

Ideally, if you can figure out how to update your cluster’s CUDA version and use PT 1.11, that would be better.

Also, it would be helpful to be able to assign different numbers of threads to different RPC processes. The master process, which serves as the central node for one-to-all communication, needs more threads than the worker processes. The current setting, which applies the same fixed maximum number of threads to every process, can easily result in an overwhelming total number of threads on a single node.

This can already be done by setting the num_worker_threads on the master node (rank 0) separately, e.g.

import torch.distributed.rpc as rpc

# on rank 0 (master), use a larger thread pool; n is the total world size
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
    num_worker_threads=128,
)
rpc.init_rpc("master", rank=0, world_size=n, rpc_backend_options=rpc_backend_options)

# on rank 1+, use the default backend options
rpc.init_rpc("worker1", rank=1, world_size=n)

@H-Huang Thank you so much for your solution!

Hi @H-Huang! I have tried your stub process group method.

Your code looks fine. The only way to continue using PT 1.10.1 along with your hack would be to create a stub process group to use as a placeholder, like below, since the ProcessGroup is not used in RPC besides shutdown:

However, a call to torch.distributed.ProcessGroup(-1, -1) resulted in TypeError: torch._C._distributed_c10d.ProcessGroup: No constructor defined!. Any clue about this? Thanks!

Hey @Levi_Ackerman, my bad. It looks like the Python ProcessGroup extensibility was added here: Add pybind trampoline for ProcessGroup and Work by mrshenli · Pull Request #66338 · pytorch/pytorch · GitHub. I believe this may have been released in PT 1.11 rather than in PT 1.10.1.
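
If it helps to fail fast on older builds, a small guard like the following could be added before constructing the stub (a sketch only; the 1.11 cutoff is based on my reading of that PR, and ProcessGroupStub is the class from the snippet above):

import torch

# Subclassing torch.distributed.ProcessGroup from Python relies on the pybind
# trampoline, which appears to be available only from PT 1.11 onwards.
major, minor = (int(x) for x in torch.__version__.split(".")[:2])
if (major, minor) >= (1, 11):
    group = ProcessGroupStub()
else:
    raise RuntimeError(
        f"Subclassing ProcessGroup from Python needs PyTorch >= 1.11, found {torch.__version__}"
    )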