Is there a way to send the location of a PyTorch tensor in GPU memory between Docker containers and rebuild it in a different container?

To quickly sum up the problem: I need to transfer images (size (1920, 1200, 3)) between PyTorch Docker containers and process them. The containers run on the same system. Speed is very important and a transfer should not take more than 2-3 ms one way. The two containers share IPC, so I have no problem transferring NumPy arrays via shared-memory buffers (see the example in multiprocessing.shared_memory — Shared memory for direct access across processes — Python 3.10.5 documentation). I am curious whether there is a similar way to do that with PyTorch tensors allocated on the GPU.
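
For the CPU path, what I do is essentially the pattern from those docs. A minimal sketch (the block name "image_buf" and the uint8 dtype are just examples, not my actual code):

import numpy as np
from multiprocessing import shared_memory

# Producer: back a (1920, 1200, 3) image with a named shared-memory block.
frame = np.random.randint(0, 255, (1920, 1200, 3), dtype=np.uint8)
shm = shared_memory.SharedMemory(create=True, size=frame.nbytes, name='image_buf')
shared = np.ndarray(frame.shape, dtype=frame.dtype, buffer=shm.buf)
shared[:] = frame[:]  # one copy into shared memory

# Consumer (the other process/container): attach by name, no copy of the pixels.
existing = shared_memory.SharedMemory(name='image_buf')
view = np.ndarray((1920, 1200, 3), dtype=np.uint8, buffer=existing.buf)
print(view.mean())

existing.close()
shm.close()
shm.unlink()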

From what I’ve learned, CUDA tensors are already in shared memory. I tried transferring them and PyTorch Tensor Storage objects via a socket, but it takes around 50-60 ms one way, which is far too slow. For testing purposes, I just run the two programs in separate terminals.

Container 1 code:

import torch
import zmq

def main():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect('tcp://0.0.0.0:6000')

    x = torch.randn((1, 1920, 1200, 3), device='cuda')
    storage = x.storage()

    while True:
        # send_pyobj pickles the whole CUDA storage on every send,
        # which copies the data rather than sharing it.
        sock.send_pyobj(storage)
        sock.recv()

if __name__ == "__main__":
    main()

Container 2 code:

import torch
import zmq
import time

def main():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind('tcp://*:6000')

    for i in range(10):
        before = time.time()
        storage = sock.recv_pyobj()

        # Wrap the received storage in an empty tensor on the same device.
        tensor = torch.tensor((), device=storage.device)
        tensor.set_(storage)

        after = time.time()
        print(after - before)

        sock.send_string('')

if __name__ == "__main__":
    main()

I found a similar topic discussed four years ago. There, the person extracts additional information from the storage using the _share_cuda_() function, which gives a cudaIpcMemHandle_t.

Is there a way to reconstruct the Storage/Tensor from the cudaIpcMemHandle_t, or from the information extracted with _share_cuda_(), using PyTorch functions? Or is there a better way to achieve the same result?
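
For reference, getting the handle on the sending side is straightforward; this is just an inspection sketch (the tuple layout matches the unpacking in the updated code below):

import torch

x = torch.randn((1, 1920, 1200, 3), device='cuda:0')
fields = x.storage()._share_cuda_()

# fields[1] is the raw cudaIpcMemHandle_t, exposed as bytes that can be
# pickled and sent to the other container.
print(len(fields))
print(type(fields[1]))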

I found a function in torch.multiprocessing.reductions that rebuilds tensors from the output generated by _share_cuda_(). Now my code looks something like this:

Container 1 code:

import torch
import zmq

def main():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect('tcp://0.0.0.0:6000')

    image = torch.randn((1, 1920, 1200, 3), dtype=torch.float, device='cuda:0')
    storage = image.storage()

    # _share_cuda_() exports the storage over CUDA IPC and returns everything
    # needed to reopen the same allocation from another process.
    (storage_device, storage_handle, storage_size_bytes, storage_offset_bytes,
     ref_counter_handle, ref_counter_offset, event_handle, event_sync_required) = storage._share_cuda_()

    while True:
        sock.send_pyobj({
            "dtype": image.dtype,
            "tensor_size": tuple(image.size()),  # (1, 1920, 1200, 3) -- must match the stride below
            "tensor_stride": image.stride(),
            "tensor_offset": image.storage_offset(),  # !Not sure about this one.
            "storage_cls": type(storage),
            "storage_device": storage_device,
            "storage_handle": storage_handle,
            "storage_size_bytes": storage_size_bytes,
            "storage_offset_bytes": storage_offset_bytes,
            "requires_grad": False,
            "ref_counter_handle": ref_counter_handle,
            "ref_counter_offset": ref_counter_offset,
            "event_handle": event_handle,
            "event_sync_required": event_sync_required,
        })

        sock.recv_string()

if __name__ == "__main__":
    main()

Container 2 code:

import torch
import zmq
import time
from torch.multiprocessing.reductions import rebuild_cuda_tensor


def main():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind('tcp://*:6000')

    for i in range(10):
        before = time.time()

        cuda_tensor_info = sock.recv_pyobj()
        # Reopens the exported storage through CUDA IPC and wraps it in a tensor.
        rebuilt_tensor = rebuild_cuda_tensor(torch.Tensor, **cuda_tensor_info)

        after = time.time()
        print(after - before)

        sock.send_string('')

if __name__ == "__main__":
    main()
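
If I understand CUDA IPC correctly, the rebuilt tensor aliases the same GPU allocation as the image tensor in container 1, so container 2 reads the producer's data without a copy. A small sanity check I would add right after the rebuild_cuda_tensor call (my assumption about the semantics, not part of the code above):

        # rebuilt_tensor should point at the producer's buffer: same device,
        # same shape, and reading it returns the producer's data.
        assert rebuilt_tensor.is_cuda
        assert rebuilt_tensor.shape == (1, 1920, 1200, 3)
        print(rebuilt_tensor.mean().item())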