The example posted in the original post seems to produce the correct behavior on WSL2, CUDA 12.2, Python 3.10.12, and PyTorch 2.0.1. However, the code below produces very odd behavior: it seems that placing the module in shared memory and then passing it to a spawned process resets its parameters to zero.
import torch
import threading
from torch import multiprocessing as mp

def goo(l):
    print("entering goo")
    print(list(l.parameters()))
    x = l(torch.rand(2, device='cuda'))
    print(x)
    print("leaving goo")

def run():
    ctx = mp.get_context('spawn')
    l = torch.nn.Linear(2, 2).to('cuda').share_memory()

    print("##### regular call #####")
    goo(l)

    print("##### thread call #####")
    thread = threading.Thread(target=goo, args=(l, ))
    thread.start()
    thread.join()

    print("##### process call #####")
    process = ctx.Process(target=goo, args=(l,))
    process.start()
    process.join()

    print("##### thread call #####")
    thread = threading.Thread(target=goo, args=(l, ))
    thread.start()
    thread.join()

    print("##### regular call #####")
    goo(l)

if __name__ == '__main__':
    run()
outputs
##### regular call #####
entering goo
[Parameter containing:
tensor([[ 0.2169, -0.5305],
[ 0.1662, -0.3463]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.4492, 0.0421], device='cuda:0', requires_grad=True)]
tensor([-0.9618, -0.2915], device='cuda:0', grad_fn=<AddBackward0>)
leaving goo
##### thread call #####
entering goo
[Parameter containing:
tensor([[ 0.2169, -0.5305],
[ 0.1662, -0.3463]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.4492, 0.0421], device='cuda:0', requires_grad=True)]
tensor([-0.7600, -0.1606], device='cuda:0', grad_fn=<AddBackward0>)
leaving goo
##### process call #####
entering goo
[Parameter containing:
tensor([[0., 0.],
[0., 0.]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0., 0.], device='cuda:0', requires_grad=True)]
tensor([0., 0.], device='cuda:0', grad_fn=<AddBackward0>)
leaving goo
##### thread call #####
entering goo
[Parameter containing:
tensor([[0., 0.],
[0., 0.]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0., 0.], device='cuda:0', requires_grad=True)]
tensor([0., 0.], device='cuda:0', grad_fn=<AddBackward0>)
leaving goo
##### regular call #####
entering goo
[Parameter containing:
tensor([[0., 0.],
[0., 0.]], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0., 0.], device='cuda:0', requires_grad=True)]
tensor([0., 0.], device='cuda:0', grad_fn=<AddBackward0>)
leaving goo
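To narrow down whether the shared CUDA storage is being re-allocated or actually overwritten with zeros, one thing that could help is printing the parameters' data_ptr() in the parent before and after the process call, and inside the child. This is just a diagnostic sketch under the same setup as above (the child/diagnose function names are only illustrative):

import torch
from torch import multiprocessing as mp

def child(l):
    # In the spawned process: report where the shared parameters live and what they hold.
    print("child data_ptr:", l.weight.data_ptr(), l.bias.data_ptr())
    print("child weight:", l.weight)

def diagnose():
    ctx = mp.get_context('spawn')
    l = torch.nn.Linear(2, 2).to('cuda').share_memory()

    # Record the CUDA storage addresses before handing the module to a child process.
    print("parent before:", (l.weight.data_ptr(), l.bias.data_ptr()), l.weight)

    p = ctx.Process(target=child, args=(l,))
    p.start()
    p.join()

    # If the addresses change, the storage was re-allocated; if they are unchanged
    # but the values are now zero, the shared storage itself was overwritten.
    print("parent after:", (l.weight.data_ptr(), l.bias.data_ptr()), l.weight)

if __name__ == '__main__':
    diagnose()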