Fetching data from remote server in dataloader

I have a large HDF5 file (~100GB) containing image features extracted with ResNet. The file is located on my local machine (a laptop), while my model is trained on a cluster node with a storage limit of 25GB.
Right now, I am using torch.distributed.rpc to transfer data from my local machine to the cluster.
I expose a server on my local machine as follows:

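```python
import os

import torch
import torch.distributed.rpc as rpc

import utils

num_worker = 4
# Rank 0 is this server; ranks 1..num_worker are the DataLoader workers.
utils.WORLD_SIZE = num_worker + 1


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8888'
    rpc.init_rpc(utils.SERVER_NAME,
                 rank=rank,
                 world_size=world_size)
    print("Server Initialized", flush=True)
    # Block until every worker has shut down, so the server keeps serving
    # requests instead of exiting right after initialization.
    rpc.shutdown()


if __name__ == "__main__":
    rank = 0
    world_size = utils.WORLD_SIZE
    run_worker(rank, world_size)
```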

This server sends data from my local machine to the cluster (other classes are omitted).

To request data on the cluster side, I initialize RPC in each DataLoader worker through `worker_init_fn`:

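```python
import torch
import torch.distributed.rpc as rpc

import utils


def worker_init_fn(worker_id):
    # Each DataLoader worker joins the RPC group with its own rank (1..num_worker).
    # "worker{n}" is a placeholder name; every RPC participant needs a unique one.
    rpc.init_rpc(f"worker{worker_id + 1}",
                 rank=worker_id + 1, world_size=utils.WORLD_SIZE)
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    worker_id = worker_info.id
    server_info = rpc.get_worker_info(utils.SERVER_NAME)
    # Remote reference to a Server object living on the laptop (rank 0).
    dataset.server_ref = rpc.remote(server_info, utils.Server)
```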
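The post omits both `utils.Server` and the dataset's `__getitem__`, so the following is only a hypothetical sketch of how `dataset.server_ref` might be consumed. It assumes a stand-in `Server` that serves rows of a small in-memory tensor instead of the ~100GB HDF5 file, and a `RemoteFeatureDataset` whose `__getitem__` makes a blocking call through the RRef with `RRef.rpc_sync()` (available in recent PyTorch releases). `Server`, `RemoteFeatureDataset`, `get_item`, `num_items` and port 29501 are illustrative names, not the author's code. To run on one machine, a single RPC process acts as both server and client.

```python
import os

import torch
import torch.distributed.rpc as rpc
from torch.utils.data import Dataset


class Server:
    """Hypothetical stand-in for utils.Server: serves rows of an in-memory tensor."""

    def __init__(self, num_items=8, feature_dim=4):
        self.features = torch.arange(num_items * feature_dim,
                                     dtype=torch.float32).view(num_items, feature_dim)
        self.labels = torch.arange(num_items) % 2

    def num_items(self):
        return self.features.shape[0]

    def get_item(self, idx):
        # Executes on the RRef owner (the laptop in the post's setup).
        return self.features[idx], self.labels[idx]


class RemoteFeatureDataset(Dataset):
    """Hypothetical dataset that fetches each item from the remote Server."""

    def __init__(self, length):
        self.length = length
        self.server_ref = None  # assigned later, e.g. in worker_init_fn

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Blocking RPC to the owner of server_ref; returns (features, label).
        return self.server_ref.rpc_sync().get_item(idx)


def demo():
    # Single-process demo: rank 0 is both server and client, which is the
    # only situation where a localhost MASTER_ADDR is enough.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    rpc.init_rpc("server", rank=0, world_size=1)
    try:
        server_ref = rpc.remote("server", Server)
        dataset = RemoteFeatureDataset(server_ref.rpc_sync().num_items())
        dataset.server_ref = server_ref
        features, label = dataset[3]
        return len(dataset), features.tolist(), int(label)
    finally:
        rpc.shutdown()
```

Calling `demo()` should return `(8, [12.0, 13.0, 14.0, 15.0], 1)`. In the real two-machine setup, every `__getitem__` call is a network round-trip from a DataLoader worker to the laptop.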

When I run my code, the training loop completes one full iteration over the dataset and then hangs, and I get the following error on the cluster side:

```
Traceback (most recent call last):
  File "custom_datasets.py", line 134, in <module>
  File "custom_datasets.py", line 110, in main
    for i, (images, labels) in enumerate(mn_dataset_loader):
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 135, in _worker_loop
  File "custom_datasets.py", line 78, in worker_init_fn
    rank=worker_id+1, world_size=utils.WORLD_SIZE)
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/distributed/rpc/__init__.py", line 67, in init_rpc
    store, _, _ = next(rendezvous_iterator)
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/distributed/rendezvous.py", line 168, in _env_rendezvous_handler
    store = TCPStore(master_addr, master_port, world_size, start_daemon)
RuntimeError: connect() timed out.
```

The problem doesn't occur when I set `num_worker = 0`, but then the cluster code is very slow. I think the error is caused by multi-threading, but I am not sure how to resolve it. Please help me resolve this issue.

Hey @Kanishk_Jain

Sorry about the late response.

```python
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
```

If you are trying to use the distributed package across multiple machines, `MASTER_ADDR` needs to be a routable IP address, and `localhost` is not routable from other machines. One way to test whether the specified IP works is to check whether `torch.distributed.init_process_group` can run successfully with it.
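A minimal sketch of such a test follows, assuming the gloo backend and the `tcp://` init method; `check_master_addr` and `is_loopback` are hypothetical helper names. Every process (for example, the laptop as rank 0 and the cluster node as rank 1) would call `check_master_addr` with the same address, port and `world_size`; if the address cannot be reached, `init_process_group` should fail with a timeout error instead of returning. `is_loopback` flags addresses such as `localhost` that other machines cannot reach.

```python
import datetime
import ipaddress
import socket

import torch.distributed as dist


def is_loopback(host):
    """True if `host` resolves to a loopback address (e.g. localhost -> 127.0.0.1).

    A loopback MASTER_ADDR only works when every process runs on the same machine.
    """
    return ipaddress.ip_address(socket.gethostbyname(host)).is_loopback


def check_master_addr(master_addr, master_port, rank, world_size, timeout_s=60):
    """Join a gloo process group via tcp://master_addr:master_port, then leave it.

    All `world_size` processes must call this with the same address, port and
    world size, each with a distinct rank. Returns True once the group is formed;
    if the master is unreachable, init_process_group raises an error instead.
    """
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{master_addr}:{master_port}",
        rank=rank,
        world_size=world_size,
        timeout=datetime.timedelta(seconds=timeout_s),
    )
    try:
        return dist.is_initialized()
    finally:
        dist.destroy_process_group()
```

On a single machine, `check_master_addr("127.0.0.1", 29502, rank=0, world_size=1)` should return `True`; for the laptop/cluster case, run it on both sides with the laptop's routable IP and an open port.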

BTW, could you please post follow-up questions under the “distributed-rpc” tag? The PT Distributed team subscribes to updates on the “distributed” and “distributed-rpc” tags so that we can join discussions promptly.