PyTorch hangs in forked subprocess after large torch.zeros call in main process

Hi, first time posting, apologies if I made a mistake in the categorization or anything.

I am trying to have a generator load objects in the background, and I am running into an extremely strange bug which I have distilled down to the example below. When I run the code, it hangs when calling torch.zeros in split_loader_creator, but if I remove the seemingly irrelevant line torch.zeros(152*4, 168*4).float() near the end, it makes progress. It also seems fine if I change 152*4 and 168*4 to much smaller numbers. This is on PyTorch 1.5.1; I do not encounter the issue on 1.4.0. Am I somehow using multiprocessing incorrectly? I would really appreciate any help, and please let me know if I can provide more information that might be useful.

import torch
import multiprocessing
import atexit

def split_loader_creator():
    for i in range(20):
        yield torch.zeros(10, 170, 70)

def background_generator_helper(gen_creator):
    def _bg_gen(gen_creator, conn):
        gen = gen_creator()
        while conn.recv():
            try:
                conn.send(next(gen))
            except StopIteration:
                conn.send(StopIteration)
                return
            except Exception:
                import traceback
                traceback.print_exc()

    parent_conn, child_conn = multiprocessing.Pipe()
    p = multiprocessing.Process(target=_bg_gen, args=(gen_creator, child_conn))
    p.start()
    atexit.register(p.terminate)

    parent_conn.send(True)  # prime the worker with one request up front so it is always preparing the next item
    while True:
        parent_conn.send(True)
        x = parent_conn.recv()
        if x is StopIteration:
            return
        else:
            yield x

def background_generator(gen_creator): # fetch batches in a background process so data loading keeps up with the GPU
    generator = background_generator_helper(gen_creator)
    while True:
        batch = next(generator)
        if batch is StopIteration:
            return
        yield batch

torch.zeros(152*4, 168*4).float()  # seemingly unrelated allocation; removing it, or shrinking it, avoids the hang

data_loader = background_generator(split_loader_creator)
for i, batch in enumerate(data_loader):
    print(i)
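
In case it is useful, here is the extra environment information I can think to gather (I am assuming the multiprocessing start method and the intra-op thread count are the relevant bits; these prints are just my guess at what matters):

import torch
import multiprocessing

print(torch.__version__)                   # 1.5.1 in the failing setup; 1.4.0 does not hang
print(multiprocessing.get_start_method())  # multiprocessing defaults to fork on Linux
print(torch.get_num_threads())             # size of PyTorch's intra-op thread pool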

This sounds like a regression, and I can confirm that I reproduce the reported behavior.

Could you please submit an issue to the pytorch/pytorch issue tracker on GitHub (https://github.com/pytorch/pytorch/issues) to report this bug?
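
In the meantime, a possible workaround, assuming the hang is the familiar fork-after-thread-pool-initialization problem, is to create the background process with the spawn start method instead of the default fork. The following is only a minimal sketch of that idea: the worker is moved to module level and renamed _worker so it can be pickled under spawn, and the request/response handshake is dropped for brevity.

import torch
import multiprocessing

def _worker(conn):
    # produce the same batches as split_loader_creator, then signal exhaustion
    for _ in range(20):
        conn.send(torch.zeros(10, 170, 70))
    conn.send(StopIteration)
    conn.close()

def background_generator():
    ctx = multiprocessing.get_context("spawn")  # start a fresh interpreter instead of forking
    parent_conn, child_conn = ctx.Pipe()
    p = ctx.Process(target=_worker, args=(child_conn,))
    p.start()
    while True:
        x = parent_conn.recv()
        if x is StopIteration:
            p.join()
            return
        yield x

if __name__ == "__main__":  # required with spawn, since the child re-imports this module
    torch.zeros(152 * 4, 168 * 4).float()  # the allocation that appears to trigger the hang under fork
    for i, batch in enumerate(background_generator()):
        print(i)

torch.multiprocessing is a drop-in replacement for multiprocessing and should accept the same spawn context, with the added benefit that tensors sent between processes go through shared memory.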

Thanks, I’ll give that a try.
