Key-value store shared memory with tensors

Hey y’all,

I have a quick question about making a distributed PyTorch application I built more efficient. Basically, I am implementing a full-torch version of Ape-X (https://arxiv.org/pdf/1803.00933.pdf). In it, the authors implement a shared replay buffer in shared memory with a TensorFlow key-value store (Appendix F).

My issue is that, since I can't really control how much compute each actor uses (I'm assuming one device per item in world_size), my CPUs just get destroyed handling all these RPCs. My hope was to build this system to scale well, independently of the number of actors.

My question: are there ops like TensorFlow's lookup module for this? Can I put tensors in shared memory? I understand that TCPStore only takes strings in its set() args. Is there a recommended torch pattern for the case where I have centralized data that needs to be read by a learner and added to by a bunch of actors generating data, without overloading anything with requests?

Thanks for any time y’all.

Hi theoryofjake,

"Can I put tensors in shared memory?"

You can move tensors to shared memory with torch.Tensor.share_memory_ (see the PyTorch 2.1 documentation).

A possible solution is to use a Python dictionary and move its tensor values to shared memory. Would that work for your case?

Could you give an example? My thought was to use something like torch.distributed's TCPStore:

store = torch.distributed.TCPStore(..)
rand_tensor = torch.rand((64, 10))
store.set('rand_tensor1', rand_tensor)  # fails: set() expects a str value

but this errors because set() only takes str as the second arg.
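Right: TCPStore.set() expects str keys and values, so a tensor can't go in directly. If you do want to route tensors through the store, one workaround (just a sketch, untested, and it copies and serializes on every call, so it won't help your RPC overhead) is to serialize with torch.save and base64-encode the bytes into a str. The set_tensor/get_tensor helpers below are hypothetical, not a torch API:

import base64
import io

import torch

def set_tensor(store, key, tensor):
    # Serialize the tensor to bytes with torch.save, then encode as a str for set().
    buf = io.BytesIO()
    torch.save(tensor, buf)
    store.set(key, base64.b64encode(buf.getvalue()).decode("ascii"))

def get_tensor(store, key):
    # get() returns bytes; decode and deserialize back into a tensor.
    raw = base64.b64decode(store.get(key))
    return torch.load(io.BytesIO(raw))

# e.g., with the store from your snippet:
# set_tensor(store, 'rand_tensor1', torch.rand((64, 10)))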

On a single machine, though, shared memory avoids the serialization entirely. Here's an example using tensor.share_memory_():

import torch
import torch.multiprocessing as mp

def run(rank, python_dict):
    # In-place update; only visible to the parent if the tensor is in shared memory.
    python_dict[rank] += torch.randn(2)

def run_example(share_memory):
    print(f"share_memory={share_memory}")
    nproc = 4
    python_dict = {}
    for i in range(nproc):
        python_dict[i] = torch.zeros(2)
        if share_memory:
            python_dict[i].share_memory_()
    print(f"before={python_dict}")
    processes = []
    for rank in range(nproc):
        p = mp.Process(target=run, args=(rank, python_dict))
        processes.append(p)
    # Start all workers first, then join, so they actually run concurrently.
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
    print(f"after={python_dict}\n")

if __name__ == "__main__":
    run_example(share_memory=False)
    run_example(share_memory=True)
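When you run this, with share_memory=False each child updates its own copy of the tensors, so the "after" dict in the parent is still all zeros; with share_memory=True the in-place updates land in shared memory and the parent sees them.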
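For your learner/actors case, here's a minimal sketch of the same idea, assuming everything lives on one machine: preallocate the buffer storage in shared memory and only synchronize the write index. All the names here (SharedReplayBuffer, add, sample, obs_dim) are hypothetical, not a torch API:

import torch
import torch.multiprocessing as mp

class SharedReplayBuffer:
    # Hypothetical sketch: fixed-capacity ring buffer backed by shared-memory tensors.
    def __init__(self, capacity, obs_dim):
        self.capacity = capacity
        self.obs = torch.zeros(capacity, obs_dim).share_memory_()
        self.next_idx = torch.zeros(1, dtype=torch.long).share_memory_()
        self.lock = mp.Lock()  # serializes writers

    def add(self, obs):
        # Called by actor processes: write in place into shared storage.
        with self.lock:
            idx = self.next_idx.item() % self.capacity
            self.obs[idx].copy_(obs)
            self.next_idx += 1

    def sample(self, batch_size):
        # Called by the learner; advanced indexing returns a copy of the sampled rows.
        # (Assumes at least one item has been added.)
        size = min(self.next_idx.item(), self.capacity)
        idxs = torch.randint(0, size, (batch_size,))
        return self.obs[idxs]

Actors write experience straight into the shared tensors with no RPCs, so adding actors shouldn't pile load onto a single store process; only the index bump is under the lock. A sample() can race with a concurrent write, which is usually acceptable for replay, but a stricter version would snapshot next_idx under the lock too.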