Hi, when I get data from a DataLoader and pass it to a subprocess, both my model and the subprocess block. But if I create the DataLoader and fetch the data inside the subprocess, the model works normally.
The following code blocks:
from torch.multiprocessing import Pool

def func(net, d):
    out = net(d)

if __name__ == '__main__':
    net = Net(input_w=28*28, width=64, n_layer=3, output_w=10)  # dense network
    trainloader = get_data()  # get trainloader
    data, label = next(iter(trainloader))
    data = data.view(data.size(0), -1)
    with Pool() as pool:
        pool.starmap(func, [(net, data)])
The following code works normally:
from torch.multiprocessing import Pool

def func(net):
    trainloader = get_data()  # get trainloader
    data, label = next(iter(trainloader))
    data = data.view(data.size(0), -1)
    out = net(data)

if __name__ == '__main__':
    net = Net(input_w=28*28, width=64, n_layer=3, output_w=10)  # dense network
    with Pool() as pool:
        pool.starmap(func, [(net, )])
I don’t know what is causing this problem. Thanks, everyone.
Hmm, the following code works for me locally. Can you try it in your dev env?
import torch
from torch.multiprocessing import Pool

def func(net, d):
    out = net(d)
    print(out)

if __name__ == '__main__':
    net = torch.nn.Linear(2, 2)
    data = torch.zeros(2, 2)
    with Pool() as pool:
        pool.starmap(func, [(net, data)])
I tried your code and found a strange problem! If I pass torch.zeros(32, 28*28), it works. But if I pass torch.zeros(64, 28*28), it hangs. This doesn’t happen on my MacBook, but it does on my Linux PC.
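For anyone hitting the same size-dependent hang: I can’t confirm the root cause, but on Linux the default fork start method can interact badly with threads or shared-memory handles the parent process already holds (e.g. from a DataLoader), while macOS defaults to spawn. A sketch worth trying, with the caveat that the set_start_method('spawn') call and the .clone() are my guesses, not a verified fix:

```python
import torch
import torch.multiprocessing as mp

def func(net, d):
    # forward pass in the worker process
    out = net(d)
    print(out.shape)

if __name__ == '__main__':
    # 'spawn' starts workers fresh instead of forking the parent,
    # avoiding inherited locks/handles that can deadlock on Linux
    mp.set_start_method('spawn', force=True)
    net = torch.nn.Linear(28 * 28, 10)
    data = torch.zeros(64, 28 * 28)
    # .clone() copies the batch into plain memory, detaching it from
    # any shared-memory segment a DataLoader might still own (guess)
    data = data.clone()
    with mp.Pool() as pool:
        pool.starmap(func, [(net, data)])
```

If spawn alone doesn’t help, torch.multiprocessing also exposes set_sharing_strategy, which changes how tensors are shared between processes on Linux.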