Deadlock with torch.distributed.rpc with num_workers > 1

I have a large (93 GB) .h5 file containing image features on my local system, and my model is trained on a SLURM cluster (ADA) that has a 25 GB storage limit.
I am trying to use the torch.distributed.rpc framework to request image features in Dataset.__getitem__ via a remote call to an RPC server running on my local system.

Code for initializing RPC server (local system):

import os
import torch.distributed.rpc as rpc

import utils  # project module defining SERVER_NAME / CLIENT_NAME

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'XX.X.XX.XX'
    os.environ['MASTER_PORT'] = 'XXXX'
    
    rpc.init_rpc(utils.SERVER_NAME,
                 rank=rank, 
                 world_size=world_size)
    print("Server Initialized", flush=True)

    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    rank = 0
    
    run_worker(rank, world_size)
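
The script above only brings up the RPC agent; the object that actually serves the features lives in the shared utils module (it is referenced later as utils.Server). For reference, here is a minimal sketch of what such a class could look like, assuming the features sit in a single HDF5 dataset; the class and method names come from my code below, but FEATURES_PATH, the dataset key 'features', and the indexing logic are placeholders:

import h5py
import torch

FEATURES_PATH = '/path/to/features.h5'  # hypothetical path to the 93 GB file on the local system

class Server:
    def __init__(self):
        # Keep the HDF5 file open so each request only pays for a single dataset read.
        self.h5 = h5py.File(FEATURES_PATH, 'r')

    def _load_image(self, args):
        index, coco_id_to_index = args
        # How index / coco_id_to_index map to a row depends on the dataset layout;
        # here we simply read one feature row and return it as a tensor.
        return torch.from_numpy(self.h5['features'][index])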

Code for the RPC client that requests data from my local system (on ADA):

import os
import torch.distributed.rpc as rpc

import utils  # project module defining SERVER_NAME / CLIENT_NAME

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'XX.X.XX.XX'
    os.environ['MASTER_PORT'] = 'XXXX'

    rpc.init_rpc(utils.CLIENT_NAME.format(rank), rank=rank, world_size=world_size)
    print("Client Initialized", flush=True)

    main()  # training loop: builds the DataLoader and iterates over it

    rpc.shutdown()

if __name__ == '__main__':
    world_size = 2
    rank = 1

    run_worker(rank, world_size)

In my DataLoader I have specified num_workers=8.
Simplified code for Dataset.__getitem__ (on ADA):

def __getitem__(self, index):
    ....
    ....

    print("fetching image for image_id {}, item {}".format(image_id, item), flush=True)
    v = utils._remote_method(utils.Server._load_image, self.server_ref, [index, self.coco_id_to_index])

    return v, ......
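
For context, self.server_ref here is presumably an RRef to a utils.Server instance created on the local machine (e.g. via rpc.remote(utils.SERVER_NAME, utils.Server)), and utils._remote_method is assumed to follow the usual helper pattern from the PyTorch RPC examples, roughly:

import torch.distributed.rpc as rpc

def _call_method(method, rref, *args, **kwargs):
    # Runs on the owner of rref: invoke the bound method on the local object.
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    # Block until the remote call on rref's owner returns its result.
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)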

Now, in my training loop, when I call enumerate(data_loader), multi-process data loading kicks in, __getitem__ is called by each of the num_workers workers, and a deadlock is reached.
I am not sure why this deadlock is occurring, because whenever __getitem__ is called a remote call should be made to the RPC server on my local system to request the data.

How can I resolve the deadlock? Is there another way to solve my problem, given that the large file doesn’t fit on my ADA system? I don’t want to compromise on latency.

Edit: When I set num_workers=0 my code works, but it is very slow, around 20 s/iteration.

Hey @Kanishk_Jain

Thanks for trying out RPC.

multi-process data loading kicks in, __getitem__ is called by each of the num_workers workers, and a deadlock is reached.

It could be that the RPC threads in the thread pool are exhausted; num_send_recv_threads defaults to 4. Does it work if you bump up the number of threads? Something like:

import os
from datetime import timedelta

import torch.distributed.rpc as rpc
from torch.distributed.rpc import ProcessGroupRpcBackendOptions

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

options = ProcessGroupRpcBackendOptions()
options.rpc_timeout = timedelta(seconds=60)
options.init_method = "env://"
options.num_send_recv_threads = 32  # default is 4; raise it above the number of concurrent callers

rpc.init_rpc("client", rank=0, world_size=2, rpc_backend_options=options)
rpc.shutdown()
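
If thread exhaustion is the problem, it may also be worth passing the same options on the server side, since with the ProcessGroup backend each node handles incoming requests with its own thread pool. A sketch mirroring the snippet above; the host/port values and the "server" name are placeholders:

import os
from datetime import timedelta

import torch.distributed.rpc as rpc
from torch.distributed.rpc import ProcessGroupRpcBackendOptions

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

options = ProcessGroupRpcBackendOptions()
options.rpc_timeout = timedelta(seconds=60)
options.init_method = "env://"
options.num_send_recv_threads = 32  # enough threads to serve all DataLoader workers concurrently

rpc.init_rpc("server", rank=1, world_size=2, rpc_backend_options=options)
rpc.shutdown()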

The current rpc_backend_options API is too verbose and not well documented. We will improve that in the next release.

Regarding the concern about speed, hopefully using more workers in the data loader will help boost throughput. We are also working on making the RPC comm layer more efficient by adding TensorPipe as a new backend, so that RPC does not have to do two round trips for each message as with ProcessGroup; TensorPipe would also allow using multiple communication media.
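
On the data-loading side, the idea would be to keep num_workers > 0 once the thread pool is large enough for every worker to issue a blocking RPC at the same time. A sketch; dataset, the batch size, and the tuple unpacking are placeholders, not from your code:

from torch.utils.data import DataLoader

# num_workers parallelizes __getitem__, so each worker issues its own blocking RPC;
# keep num_send_recv_threads comfortably above num_workers to avoid starving the pool.
data_loader = DataLoader(dataset, batch_size=32, num_workers=8, pin_memory=True)

for batch_idx, (features, targets) in enumerate(data_loader):
    ...  # training step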
