Hi, first time posting, apologies if I made a mistake in the categorization or anything.
I am trying to have a generator load objects in the background, and I am encountering an extremely strange bug, which I have distilled down to the example below. When I run the code, it hangs when trying to call `split_loader_creator`, but if I remove the seemingly irrelevant line `torch.zeros(152*4, 168*4).float()` near the end, it makes progress. It also seems fine if I change `152*4` and `168*4` to much smaller numbers. This is on PyTorch 1.5.1; I do not encounter the issue on 1.4.0. Am I somehow using multiprocessing incorrectly? I would really appreciate any help, and please let me know if I can provide more information.
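In case it helps with narrowing this down, one idea I had (an assumption on my part that it catches anything useful) is to have the standard library's `faulthandler` dump tracebacks while the script is stalled:

```python
import faulthandler

# Dump the traceback of every thread after 10 seconds, and every
# 10 seconds thereafter (repeat=True), to see where the hang is.
faulthandler.dump_traceback_later(10, repeat=True)
```

Here is the full example: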
```python
import torch
import multiprocessing
import atexit


def split_loader_creator():
    for i in range(20):
        yield torch.zeros(10, 170, 70)


def background_generator_helper(gen_creator):
    def _bg_gen(gen_creator, conn):
        gen = gen_creator()
        while conn.recv():
            try:
                conn.send(next(gen))
            except StopIteration:
                conn.send(StopIteration)
                return
            except Exception:
                import traceback
                traceback.print_exc()

    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_bg_gen, args=(gen_creator, child_conn))
    p.start()
    atexit.register(p.terminate)

    parent_conn.send(True)  # prime one batch ahead of the consumer
    while True:
        parent_conn.send(True)
        x = parent_conn.recv()
        if x is StopIteration:
            return
        else:
            yield x


def background_generator(gen_creator):
    # get several processes in the background fetching batches in parallel to keep up with gpu
    generator = background_generator_helper(gen_creator)
    while True:
        batch = next(generator)
        if batch is StopIteration:
            return
        yield batch


torch.zeros(152*4, 168*4).float()  # removing this seemingly irrelevant line avoids the hang

data_loader = background_generator(split_loader_creator)
for i, batch in enumerate(data_loader):
    print(i)
```
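One workaround idea I have been wondering about (not confirmed, and I would love to know if it is even the right direction) is switching to the `spawn` start method, so the worker does not inherit whatever state `fork()` copies from the parent after the big allocation. Below is a minimal sketch of what I mean; note that under `spawn` the worker function has to live at module level so it can be pickled, and the `__main__` guard becomes mandatory:

```python
import atexit
import multiprocessing

import torch


def split_loader_creator():
    for i in range(20):
        yield torch.zeros(10, 170, 70)


# Under "spawn" the target must be picklable, so it is defined at
# module level instead of nested inside a helper.
def _bg_gen(gen_creator, conn):
    gen = gen_creator()
    while conn.recv():
        try:
            conn.send(next(gen))
        except StopIteration:
            conn.send(StopIteration)
            return


def background_generator(gen_creator):
    ctx = multiprocessing.get_context("spawn")  # fresh interpreter, no fork()
    parent_conn, child_conn = ctx.Pipe()
    p = ctx.Process(target=_bg_gen, args=(gen_creator, child_conn))
    p.start()
    atexit.register(p.terminate)
    parent_conn.send(True)  # prime one batch ahead, as in the original
    while True:
        parent_conn.send(True)
        x = parent_conn.recv()
        if x is StopIteration:
            return
        yield x


if __name__ == "__main__":
    torch.zeros(152 * 4, 168 * 4).float()
    for i, batch in enumerate(background_generator(split_loader_creator)):
        print(i)
```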