I have a large (93 GB) .h5 file containing image features on my local system, and my model is trained on a SLURM cluster (ADA) that has a storage limit of 25 GB.
I am trying to use the torch.distributed.rpc framework to request image features inside Dataset.__getitem__ via a remote call to an RPC server running on my local system.
Code for initializing the RPC server (local system):
import os
import torch.distributed.rpc as rpc

import utils  # my helper module defining SERVER_NAME, CLIENT_NAME, etc.

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'XX.X.XX.XX'
    os.environ['MASTER_PORT'] = 'XXXX'
    # rank 0 is the master and owns the feature file
    rpc.init_rpc(utils.SERVER_NAME,
                 rank=rank,
                 world_size=world_size)
    print("Server Initialized", flush=True)
    rpc.shutdown()  # blocks here, serving RPCs, until the client shuts down

if __name__ == "__main__":
    world_size = 2
    rank = 0
    run_worker(rank, world_size)
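The Server class that owns the .h5 file lives in utils and is not shown here; for context, a minimal sketch of its read path, assuming h5py and a hypothetical 'features' dataset name inside the file:

import h5py
import torch

class Server:
    def __init__(self, h5_path):
        # Open the feature file once, read-only, and keep the handle
        # alive across requests.
        self.h5 = h5py.File(h5_path, 'r')

    def _load_image(self, query):
        # query arrives as the single list argument [index, coco_id_to_index]
        # passed from __getitem__ below; the id-to-row mapping is elided.
        index, coco_id_to_index = query
        return torch.from_numpy(self.h5['features'][index])

The dataset's self.server_ref would then be an RRef to this object, created from the ADA side with something like rpc.remote(utils.SERVER_NAME, utils.Server, args=(h5_path,)).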
Code for the RPC client that requests data from the local system (on ADA):
import os
import torch.distributed.rpc as rpc

import utils  # same helper module as on the local system

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'XX.X.XX.XX'
    os.environ['MASTER_PORT'] = 'XXXX'
    # rank 1 joins the same RPC group as the client
    rpc.init_rpc(utils.CLIENT_NAME.format(rank),
                 rank=rank,
                 world_size=world_size)
    print("Client Initialized", flush=True)
    main()  # training loop; builds the dataset and DataLoader
    rpc.shutdown()

if __name__ == '__main__':
    world_size = 2
    rank = 1
    run_worker(rank, world_size)
In my DataLoader I have specified num_workers=8.
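(For completeness, the loader is constructed roughly like this; the batch size is an assumption, only num_workers is relevant here:)

from torch.utils.data import DataLoader

data_loader = DataLoader(dataset,
                         batch_size=32,  # assumed value, not from the post
                         num_workers=8)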
Simplified code for Dataset.__getitem__ (on ADA):
def __getitem__(self, index):
    ....
    ....
    print("fetching image for image_id {}, item {}".format(image_id, item), flush=True)
    # blocking remote call to the Server instance on my local system
    v = utils._remote_method(utils.Server._load_image, self.server_ref,
                             [index, self.coco_id_to_index])
    return v, ......
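utils._remote_method is not shown either; presumably it is the usual helper pattern from the PyTorch RPC tutorials, which proxies a method call to the node that owns the RRef (a sketch, your version may differ):

import torch.distributed.rpc as rpc

def _call_method(method, rref, *args, **kwargs):
    # Executed on the owner of rref: unwrap the RRef and call the
    # method on the local Server instance.
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    # Synchronous RPC to the owner of rref (my local system); blocks
    # the caller until the result (the feature tensor) comes back.
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)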
Now in my training loop, when I call enumerate(data_loader), multi-process data loading kicks in: __getitem__ is called num_workers times and then a deadlock is reached.
I am not sure why this deadlock is occurring, because whenever __getitem__ is called, a remote call should be made to the RPC server on my local system to request the data.
How can I resolve the deadlock? Is there any other way to solve my problem, given that the large file doesn't fit on my ADA system? I don't want to compromise on latency.
Edit: When I set num_workers=0 my code works, but it is very slow: 20 sec/iteration.