`Queue.get()` randomly hangs when sharing tensors through `torch.multiprocessing`

Hi everyone,
I found that when getting (CUDA) tensors from a multiprocessing queue, the program randomly gets stuck. I wrote a snippet to reproduce the problem:

import torch
import time
from torch.multiprocessing import set_start_method,Queue

# The producer allocates CUDA tensors; PyTorch requires the 'spawn' (or
# 'forkserver') start method to use CUDA in subprocesses. The global setting
# also decides which context the default-context Queue(...) in __main__ uses.
try:
    set_start_method('spawn')
except RuntimeError:
    # set_start_method() raises once a start method has been fixed. A repeated
    # 'spawn' is harmless; a different method is surfaced instead of being
    # silently swallowed.
    if torch.multiprocessing.get_start_method() != 'spawn':
        import warnings
        warnings.warn(
            "multiprocessing start method is already {!r}; expected 'spawn'".format(
                torch.multiprocessing.get_start_method()))

def f(nproc,q,generate_device_list):
    """Producer: put freshly generated random training batches on *q* forever.

    Intended as the target of ``torch.multiprocessing.spawn``, which passes the
    process index as the first positional argument. The loop never returns on
    its own; it only ends if ``q.put`` (or tensor creation) raises.

    Args:
        nproc: index of this process, supplied by ``spawn`` (unused).
        q: queue that receives one dict per batch (see
            ``_make_training_data_map`` for the keys).
        generate_device_list: non-empty sequence of ``torch.device``. The 8
            per-view tensors of each list are placed on it round-robin; the
            three small tensors go on its first entry.

    Raises:
        ValueError: if ``generate_device_list`` is empty.
    """
    if not generate_device_list:
        raise ValueError("generate_device_list must contain at least one device")
    # The device layout is identical for every batch, so resolve it once
    # instead of recomputing i % len(...) for every tensor of every batch.
    n_devices = len(generate_device_list)
    view_devices = [generate_device_list[i % n_devices] for i in range(8)]
    main_device = generate_device_list[0]
    while True:
        # q.put() blocks while a bounded queue is full.
        # NOTE(review): per the PyTorch multiprocessing notes, CUDA tensors are
        # shared with the consumer rather than copied, so this process must
        # stay alive while the consumer still holds them.
        q.put(_make_training_data_map(view_devices, main_device))


def _make_training_data_map(view_devices, main_device):
    """Return one random batch dict.

    The three ``*_list`` entries hold one tensor per device in *view_devices*
    (shapes (50, 24576, 1), (50, 24576, 1) and (50, 384, 1)). The remaining
    entries are (50, 3) tensors on *main_device*.
    """
    return {
        "multiview_lumitexel_list":[torch.randn(50,24576,1,device=d) for d in view_devices],
        "rendered_slice_diff_gt_list":[torch.randn(50,24576,1,device=d) for d in view_devices],
        "rendered_slice_spec_gt_list":[torch.randn(50,384,1,device=d) for d in view_devices],
        "normal_label":torch.randn(50,3,device=main_device),
        "input_positions":torch.randn(50,3,device=main_device),
        "geometry_normal":torch.randn(50,3,device=main_device),
    }

if __name__ == '__main__':
    # Consumer: spawn one producer running f() and copy every received batch
    # onto training_device.
    # NOTE(review): the random hang is not root-caused here. The producer puts
    # CUDA tensors on the queue, which torch.multiprocessing shares between
    # processes instead of copying. A commonly suggested workaround is to send
    # CPU tensors and move them to the GPU in this process -- confirm before
    # relying on it.
    import queue  # queue.Empty is raised by Queue.get(timeout=...)

    torch.random.manual_seed(2333)
    q = Queue(100)
    generate_device_list = [torch.device("cuda:{}".format(i)) for i in range(4)]
    training_device = torch.device("cuda:0")
    # Seconds get() waits before checking whether the producer is still alive.
    get_timeout = 30.0

    ctx = torch.multiprocessing.spawn(f, args=(q,generate_device_list), nprocs=1, join=False, daemon=True)
    time.sleep(30.0)
    try:
        print(q.qsize())
    except NotImplementedError:
        # Queue.qsize() is not implemented on every platform (e.g. macOS).
        pass

    def _get_next_batch():
        """Wait for the next batch; return None once the producer has exited.

        A plain q.get() would block forever if the producer died. Instead, on
        each timeout ctx.join(timeout=0) is polled: it re-raises a producer
        failure (exception or non-zero exit code) and returns True once every
        spawned process has exited.
        """
        while True:
            try:
                return q.get(timeout=get_timeout)
            except queue.Empty:
                if ctx.join(timeout=0):
                    return None
                print("[MAIN] queue empty for {}s, producer still alive; waiting...".format(get_timeout))

    counter = 0
    while True:
        print("[MAIN] before getting...")
        batch_data = _get_next_batch()
        if batch_data is None:
            print("[MAIN] producer exited; stopping")
            break

        input_positions = batch_data["input_positions"].to(training_device,copy=True)
        multiview_lumitexel_list = [a.to(training_device,copy=True) for a in batch_data["multiview_lumitexel_list"]]
        rendered_slice_diff_gt_list = [a.to(training_device,copy=True) for a in batch_data["rendered_slice_diff_gt_list"]]
        rendered_slice_spec_gt_list = [a.to(training_device,copy=True) for a in batch_data["rendered_slice_spec_gt_list"]]
        normal_label = batch_data["normal_label"].to(training_device,copy=True)
        geometry_normal = batch_data["geometry_normal"].to(training_device,copy=True)

        # Drop this process's references to the received tensors as soon as
        # they are copied, so the producer side can reclaim the shared memory.
        del batch_data
        print("[MAIN] got one item:{}".format(counter))
        counter+=1

    # Reached once the producer has exited; reaps it (re-raising any failure).
    ctx.join()

However, the problem is not deterministic: in different runs the program got stuck when `counter` reached 607 or 18901.
My environment:
Linux version 4.15.0-74-generic
Ubuntu 16.04
python3.6.7 or python 3.7.0
pytorch 1.4.0

Can anyone help me? Thanks in advance! :dizzy_face: