I would like to use multiprocessing to launch multiple training instances on a CUDA device. Since the data is common to all processes, I want to avoid copying it for every process. I'm using Python 3.8's SharedMemory from the multiprocessing.shared_memory module to achieve this.
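For reference, the attach-by-name pattern described above can be sketched in a few lines. This is a minimal illustration, not the training code: the array shape and the variable names (`owner_view`, `consumer_view`, `shares_memory`) are made up for the example, and both handles live in one process purely to keep it short. In real use the second handle would be opened in a worker process.

```python
import numpy as np
from multiprocessing import shared_memory

# Illustrative shape and dtype (assumptions, not the values from the post).
shape, dtype = (4, 3), np.float32
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize

# Owner side: allocate the block once and fill it.
shm = shared_memory.SharedMemory(create=True, size=nbytes)
owner_view = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
owner_view[:] = np.arange(12, dtype=dtype).reshape(shape)

# Consumer side (normally a different process): attach by name, no copy.
attached = shared_memory.SharedMemory(name=shm.name)
consumer_view = np.ndarray(shape, dtype=dtype, buffer=attached.buf)

# A write through one view is visible through the other,
# which shows that both arrays map the same memory.
owner_view[0, 0] = 42.0
shares_memory = bool(consumer_view[0, 0] == 42.0)

# Drop the array views before closing the handles, then free the block.
del consumer_view, owner_view
attached.close()
shm.close()
shm.unlink()
```

The point of the sketch is only that attaching to the block does not duplicate the underlying data, which matches the constant RAM usage seen on the CPU side.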
I can allocate a memory block using SharedMemory and create as many processes as I'd like with constant memory (RAM) usage. However, when I send tensors to CUDA, RAM usage scales linearly with the number of processes. It appears that when c.to(device) is called, the base data is copied for every process.
Does anyone know why this is happening? Any ideas on how to mitigate it?
Here is the sample code I’m using:
import numpy as np
from multiprocessing import shared_memory, get_context
import time
import torch
import copy

dim = 10000
batch_size = 10
sleep_time = 2
npe = 1  # number of parallel executions

# cuda
if torch.cuda.is_available():
    dev = 'cuda:0'
else:
    dev = 'cpu'
device = torch.device(dev)


def step(i, shr_name):
    existing_shm = shared_memory.SharedMemory(name=shr_name)
    np_arr = np.ndarray((dim, dim), dtype=np.float32, buffer=existing_shm.buf)
    b = np_arr[i * batch_size: (i + 1) * batch_size, :]
    b = torch.Tensor(b)

    # This is just to explicitly copy the tensor so that it has nothing to do
    # with the shared memory block
    c = copy.deepcopy(b)

    # If tensor c is sent to the cuda device, then RAM scales linearly
    # with the number of parallel executions.
    # If c is not sent to cuda device, memory consumption is constant.
    c = c.to(device)

    time.sleep(sleep_time)
    existing_shm.close()


def create_shared_block():
    a = np.random.random((dim, dim)).astype(np.float32)
    shm = shared_memory.SharedMemory(create=True, size=a.nbytes, name='sha')
    np_arr = np.ndarray(a.shape, dtype=np.float32, buffer=shm.buf)
    np_arr[:] = a[:]
    return shm, np_arr


if __name__ == '__main__':
    # create shared memory block
    shm, np_arr = create_shared_block()

    # create list of inputs to be executed in parallel
    inp = [[x, 'sha'] for x in range(npe)]
    print(inp)

    # sleep added before and after launching multiprocessing to monitor the memory consumption
    print('before pool')  # to check memory with top or htop
    time.sleep(sleep_time)

    context = get_context('spawn')
    with context.Pool(npe) as pool:
        print('after pool')  # to check memory with top or htop
        time.sleep(sleep_time)
        pool.starmap(step, inp)
        time.sleep(sleep_time)

    shm.close()
    shm.unlink()
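As an alternative to watching top or htop, the growth could be logged from inside each worker. The helper below is a hypothetical sketch (the names `peak_rss_mib` and `measure_peak_growth` are not part of the code above). It uses only the standard library, and the `resource` module is available on Unix only. It reports how much the peak resident memory of the current process grew while a single call ran.

```python
import resource
import sys


def peak_rss_mib():
    """Peak resident set size of the current process, in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def measure_peak_growth(fn, *args, **kwargs):
    """Call fn and return (result, growth of peak RSS in MiB during the call)."""
    before = peak_rss_mib()
    result = fn(*args, **kwargs)
    return result, peak_rss_mib() - before
```

Inside `step`, the transfer could then be wrapped as `c, grew = measure_peak_growth(c.to, device)` and `grew` printed per worker. If the growth is about the same for a very small tensor as for a full batch, that would suggest the extra RAM comes from per-process CUDA initialization (each process that uses CUDA creates its own context, which also occupies host memory) rather than from a copy of the shared array. That is a hypothesis to test, not a confirmed diagnosis for this setup.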