The following code works perfectly on CPU. On CUDA, the second print shows that the weights are all 0.
If I don’t pass l to the pool, it works. If I replace the pool from concurrent.futures with mp.Process, the weights are still 0.
This happens only on CUDA.
What am I doing wrong?
Python 3.10.9
PyTorch 2.0.0
CUDA 11.7
import torch
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from torch import multiprocessing as mp


def goo(l):
    return l(torch.rand(2, device='cuda'))


def run():
    ctx = mp.get_context('spawn')
    l = torch.nn.Linear(2, 2).to('cuda').share_memory()
    print(vars(l))
    pool = ProcessPoolExecutor(1, mp_context=ctx)
    pool.submit(goo, l)

    def foo():
        print(vars(l))

    thread = ThreadPoolExecutor(1)
    thread.submit(foo)


if __name__ == '__main__':
    run()
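
A minimal sketch of the mp.Process variant mentioned above, assuming it differs from the pool version only in how the child process is started. The `device` parameter and the `weights_all_zero` helper are hypothetical additions for illustration and do not appear in the original snippet.

```python
import torch
from torch import multiprocessing as mp


def goo(l, device):
    # Runs in the child process: one forward pass through the shared module.
    return l(torch.rand(2, device=device))


def weights_all_zero(module):
    # Hypothetical helper: True if every entry of the weight matrix is 0.
    return bool((module.weight == 0).all())


def run(device):
    ctx = mp.get_context('spawn')
    l = torch.nn.Linear(2, 2).to(device).share_memory()
    print('before:', weights_all_zero(l))
    # Assumption: a bare Process replaces ProcessPoolExecutor, nothing else changes.
    p = ctx.Process(target=goo, args=(l, device))
    p.start()
    p.join()
    print('after:', weights_all_zero(l))


if __name__ == '__main__':
    run('cuda' if torch.cuda.is_available() else 'cpu')
```

Unlike the snippet above, this waits for the child with join() before the second check, so the two checks bracket the child's whole lifetime.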