Hi, first time posting, apologies if I made a mistake in the categorization or anything.
I am trying to have a generator load objects in the background, and I am encountering an extremely strange bug which I have distilled down to the following example. When I run the code below, it hangs on the `torch.zeros` call in `split_loader_creator`, but if I remove the seemingly irrelevant line `torch.zeros(152*4, 168*4).float()` near the end, it makes progress. It also seems fine if I change `152*4` and `168*4` to much smaller numbers. This is on PyTorch 1.5.1, and I do not encounter the issue on 1.4.0. Am I somehow doing this multiprocessing incorrectly? I would really appreciate any help, and please let me know if I can provide more information that might be helpful.
```python
import torch
import multiprocessing
import atexit


def split_loader_creator():
    for i in range(20):
        yield torch.zeros(10, 170, 70)


def background_generator_helper(gen_creator):
    def _bg_gen(gen_creator, conn):
        gen = gen_creator()
        while conn.recv():
            try:
                conn.send(next(gen))
            except StopIteration:
                conn.send(StopIteration)
                return
            except Exception:
                import traceback
                traceback.print_exc()

    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_bg_gen, args=(gen_creator, child_conn))
    p.start()
    atexit.register(p.terminate)

    parent_conn.send(True)  # request one batch ahead of time, so the worker prefetches
    while True:
        parent_conn.send(True)
        x = parent_conn.recv()
        if x is StopIteration:
            return
        yield x


def background_generator(gen_creator):  # fetch batches in a background process to keep up with the GPU
    generator = background_generator_helper(gen_creator)
    while True:
        batch = next(generator)
        if batch is StopIteration:
            return
        yield batch


torch.zeros(152*4, 168*4).float()  # removing this seemingly irrelevant line avoids the hang

data_loader = background_generator(split_loader_creator)
for i, batch in enumerate(data_loader):
    print(i)
```
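In case it helps isolate things, here is a torch-free sketch of the same pipe-based pattern (the names `make_gen` and `background` are simplified stand-ins for my functions above, not the actual code): it exercises the request/response pipe protocol with plain Python lists instead of tensors, so any difference in behavior would point at the interaction with PyTorch rather than the multiprocessing pattern itself.

```python
import multiprocessing


def make_gen():
    # stand-in for split_loader_creator: yields plain lists instead of tensors
    for _ in range(5):
        yield [0.0] * 10


def _bg_gen(gen_creator, conn):
    # worker: on each True received, send the next item; send the
    # StopIteration class itself as an end-of-stream sentinel
    gen = gen_creator()
    while conn.recv():
        try:
            conn.send(next(gen))
        except StopIteration:
            conn.send(StopIteration)
            return


def background(gen_creator):
    # consumer side: request items one at a time over a Pipe
    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_bg_gen, args=(gen_creator, child_conn))
    p.start()
    try:
        while True:
            parent_conn.send(True)
            x = parent_conn.recv()
            if x is StopIteration:
                return
            yield x
    finally:
        p.terminate()


if __name__ == "__main__":
    for i, item in enumerate(background(make_gen)):
        print(i, len(item))
```

Note that `_bg_gen` and `make_gen` are defined at module level here, so this variant also works under the `spawn` start method, where the target must be picklable.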