Increasing memory usage when using torch.multiprocessing

Hi,

I’m using the torch.multiprocessing module to parallelize model inference. My understanding is that CUDA tensors are placed in shared memory rather than copied into each process, so memory should not be duplicated when I spawn new processes with torch.multiprocessing. But when I run the code below, memory usage grows very quickly.

import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.multiprocessing as mp 
import time 

mp1 = mp.get_context('fork')
mp2 = mp.get_context('spawn')

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10,128)
        self.fc2 = nn.Linear(128,128)
        self.fc3 = nn.Linear(128,10)

    def forward(self, x):
        x = self.fc1(x) 
        x = self.fc2(x) 
        x = self.fc3(x) 
        return x

class Work1(mp1.Process):
    def __init__(self, a):
        mp1.Process.__init__(self)
        self.a = a

    def run(self):
        # Busy-loop so the worker stays alive while memory usage is observed.
        while True:
            pass

class Work2(mp2.Process):
    def __init__(self, a):
        mp2.Process.__init__(self)
        self.a = a

    def run(self):
        while True:
            pass

def main():
    p = Net().cuda()
    p.share_memory()  # no-op for CUDA tensors, which are already shareable
    for i in range(10):
        w = Work1(p)  # fork-context workers
        w.start()
        print('start a work1')
    time.sleep(10)
    for i in range(10):
        w = Work2(p)  # spawn-context workers
        w.start()
        print('start a work2')



if __name__ == '__main__':
    main()

I know why fork does not increase memory usage immediately (copy-on-write), but could anyone tell me why spawn roughly doubles the memory I use? And why does the Net object occupy so much memory?
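
In case it helps, this is roughly how I’m checking the per-process memory (a minimal sketch, assuming psutil is installed; report_rss is just a helper name I made up for illustration):

import os
import psutil

def report_rss(tag):
    # Print the resident set size (RSS) of the current process, in MiB.
    rss = psutil.Process(os.getpid()).memory_info().rss / (1024 ** 2)
    print(f'{tag}: pid={os.getpid()} rss={rss:.1f} MiB')

I call report_rss('worker') at the top of each worker’s run() and report_rss('main') after starting the workers, and compare the fork and spawn workers that way.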