How does share_memory() actually work in PyTorch?

Sorry to bother you. I have recently been working with torch.multiprocessing and ran into some problems; I hope you can help me 😶‍🌫️

  1. If I share a model that lives on CUDA, it raises
    RuntimeError: Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending.
  2. torch.multiprocessing's Manager().Queue().get takes a long time to finish. If the queue only passed a file descriptor, I don't think it should take 1/3 of the total time. Is there a faster way? (Two workaround sketches follow the script below.)
    Here's my script:
import torch
import torch.multiprocessing as mp
from copy import deepcopy
from functools import partial
from time import monotonic_ns

import numpy as np
from torchvision import models
from tqdm import tqdm


def parallel_produce(
    queue: mp.Queue,
    model_method,
    seed_seq,  # np.random.SeedSequence handed to each worker (unused for now)
) -> None:
    pure_model: torch.nn.Module = model_method()

    # if you delete this line, the model can be passed through the queue
    pure_model.to('cuda')

    pure_model.share_memory()

    while True:
        corrupt_model = deepcopy(pure_model)
        dic = corrupt_model.state_dict()
        # corrupt the first parameter tensor in place
        dic[list(dic.keys())[0]] *= 2
        corrupt_model.share_memory()
        queue.put(corrupt_model)


def parallel(
    valid,
    iteration: int = 1000,
    process_size: int = 2,
    buffer_size: int = 2,
):
    pool = mp.Pool(process_size)
    manager = mp.Manager()
    queue = manager.Queue(buffer_size)
    seed_sequence = np.random.SeedSequence()
    model_method = partial(models.squeezenet1_1, pretrained=True)
    async_result = pool.map_async(
        partial(
            parallel_produce,
            queue,
            model_method,
        ),
        seed_sequence.spawn(process_size),
    )
    total_ns = 0
    for iter_times in tqdm(range(iteration)):
        start = monotonic_ns()
        # this takes a long time
        corrupt_model: torch.nn.Module = queue.get()
        total_ns += monotonic_ns() - start
        corrupt_model.to("cuda")
        corrupt_result = corrupt_model(valid)
        del corrupt_model
    pool.terminate()
    # total time queue.get took, in seconds
    print(total_ns / 1e9)


if __name__ == "__main__":
    valid = torch.randn(1, 3, 224, 224).to('cuda')
    parallel(valid)
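
For question 1, the error message itself hints at the workaround: a CUDA tensor that arrived from another process cannot be forwarded again, so the producer has to send fresh storage. Here is a minimal sketch of what I mean, assuming it is acceptable to ship a plain CPU state_dict instead of the CUDA module itself (parallel_produce_cpu is a hypothetical replacement for my producer above):

from copy import deepcopy

import torch


def parallel_produce_cpu(queue, model_method, seed_seq) -> None:
    """Producer that ships a CPU state_dict instead of a CUDA module."""
    pure_model = model_method().to('cuda')
    while True:
        corrupt_model = deepcopy(pure_model)
        dic = corrupt_model.state_dict()
        dic[list(dic.keys())[0]] *= 2
        # Detach and copy every tensor to CPU before queueing (.cpu() makes
        # a fresh copy for CUDA tensors), so the consumer never receives a
        # CUDA tensor that already crossed a process boundary
        # ("Consider cloning before sending").
        queue.put({k: v.detach().cpu() for k, v in dic.items()})


# Consumer side: rebuild the module from the received state_dict.
# state = queue.get()
# corrupt_model = model_method()
# corrupt_model.load_state_dict(state)
# corrupt_model.to('cuda')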
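
For question 2, my current understanding is that Manager().Queue() never passes just a file descriptor: every put/get is a round trip through the separate manager server process, and the whole pickled payload travels with it. torch.multiprocessing's own Queue, by contrast, moves CPU tensor storage into shared memory and sends only a small handle. Below is a sketch of the swap, assuming a spawn context is fine here; note that a plain mp.Queue cannot be passed to Pool workers as a map argument (which is why I fell back to the manager queue), so this sketch uses mp.Process instead:

import torch
import torch.multiprocessing as mp


def produce(queue: mp.Queue) -> None:
    while True:
        # CPU tensor storage is moved to shared memory by torch's
        # ForkingPickler reductions; only a handle goes over the pipe.
        queue.put(torch.randn(8, 3, 224, 224))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # assumption: spawn is acceptable here
    queue = ctx.Queue(maxsize=2)
    workers = [ctx.Process(target=produce, args=(queue,), daemon=True)
               for _ in range(2)]
    for w in workers:
        w.start()
    for _ in range(10):
        t = queue.get()  # receives a shared-memory handle, not raw bytes
    for w in workers:
        w.terminate()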
