Hi everyone,
I found that when getting Tensors from a multiprocessing queue, the program randomly gets stuck. I wrote a snippet to reproduce this problem:
import torch
import time
from torch.multiprocessing import set_start_method,Queue
# CUDA tensors can only be shared between processes with the 'spawn' start
# method. set_start_method raises RuntimeError if the start method has
# already been set (e.g. on re-import), in which case we keep the existing one.
try:
    set_start_method('spawn')
except RuntimeError:
    pass
def f(nproc, q, generate_device_list):
    """Producer loop: endlessly build batches of random CUDA tensors and put them on q.

    Args:
        nproc: worker index supplied by torch.multiprocessing.spawn (unused).
        q: torch.multiprocessing Queue the batches are pushed onto.
        generate_device_list: devices the large tensors are cycled across.
    """
    n_devices = len(generate_device_list)
    while True:
        # Eight large tensors per key, round-robined over the available GPUs;
        # the three small per-sample labels live on the first device.
        batch = {
            "multiview_lumitexel_list": [
                torch.randn(50, 24576, 1, device=generate_device_list[i % n_devices])
                for i in range(8)
            ],
            "rendered_slice_diff_gt_list": [
                torch.randn(50, 24576, 1, device=generate_device_list[i % n_devices])
                for i in range(8)
            ],
            "rendered_slice_spec_gt_list": [
                torch.randn(50, 384, 1, device=generate_device_list[i % n_devices])
                for i in range(8)
            ],
            "normal_label": torch.randn(50, 3, device=generate_device_list[0]),
            "input_positions": torch.randn(50, 3, device=generate_device_list[0]),
            "geometry_normal": torch.randn(50, 3, device=generate_device_list[0]),
        }
        # Blocks once the queue reaches its capacity (100).
        q.put(batch)
if __name__ == '__main__':
    torch.random.manual_seed(2333)
    q = Queue(100)
    generate_device_list = [
        torch.device("cuda:0"),
        torch.device("cuda:1"),
        torch.device("cuda:2"),
        torch.device("cuda:3"),
    ]
    training_device = torch.device("cuda:0")

    # Launch a single daemon producer; join=False returns a context immediately.
    ctx = torch.multiprocessing.spawn(
        f, args=(q, generate_device_list), nprocs=1, join=False, daemon=True
    )

    # Give the producer time to fill the queue, then report how full it got.
    time.sleep(30.0)
    print(q.qsize())

    counter = 0
    while True:
        print("[MAIN] before getting...")
        batch = q.get()
        # Copy every tensor onto the training device so the shared CUDA
        # storage from the producer can be released, then drop the batch.
        input_positions = batch["input_positions"].to(training_device, copy=True)
        multiview_lumitexel_list = [
            t.to(training_device, copy=True) for t in batch["multiview_lumitexel_list"]
        ]
        rendered_slice_diff_gt_list = [
            t.to(training_device, copy=True) for t in batch["rendered_slice_diff_gt_list"]
        ]
        rendered_slice_spec_gt_list = [
            t.to(training_device, copy=True) for t in batch["rendered_slice_spec_gt_list"]
        ]
        normal_label = batch["normal_label"].to(training_device, copy=True)
        geometry_normal = batch["geometry_normal"].to(training_device, copy=True)
        del batch
        print("[MAIN] got one item:{}".format(counter))
        counter += 1

    # NOTE(review): unreachable — the while-True loop above never exits.
    ctx.join()
However, this problem is not deterministic: the program got stuck when counter equaled 607 in one run and 18901 in another.
My environment:
Linux version 4.15.0-74-generic
Ubuntu 16.04
python3.6.7 or python 3.7.0
pytorch 1.4.0
Can anyone help me? Many thanks in advance!