I have an inference function `func` that takes about 9 s per call on average. But when I use multiprocessing to parallelize it (even with torch.multiprocessing), each inference takes about 20 s on average. Why is that?
`func` takes a `patient_name` and runs a torch model in inference mode on that patient's data:
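```python
import numpy as np
import torch

device = torch.device('cpu')

def func(patient_name):
    # Load this patient's input data and trained weights
    data = np.load(my_dict[patient_name]['data_path'])
    model_state = torch.load(my_dict[patient_name]['model_state_path'], map_location='cpu')
    # Rebuild the network from its hyperparameters and load the weights
    model = my_net(my_dict[patient_name]['HPs'])
    model = model.to(device)
    model.load_state_dict(model_state)
    model.eval()
    # Inference only: skip gradient tracking
    with torch.no_grad():
        result = model(torch.FloatTensor(data).to(device))
    return result
```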
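To tell whether each call really gets slower or only appears to, one could time every call inside the workers and compare that with the wall-clock time of the whole `starmap`. A minimal standard-library sketch; `timed_call` is a hypothetical helper, defined at module level so the pool can pickle it:

```python
import time


def timed_call(fn, *args):
    """Call fn(*args) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start
```

It could be used as `pool.starmap(timed_call, [(func, name) for name in patient_names])`, where `patient_names` is a hypothetical list of patients. For example, if 10 calls each take 20 s but all run at once, that is 10 results in roughly 20 s, compared with about 90 s for 10 sequential 9 s calls.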
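And this is how I run it in parallel:

```python
from torch.multiprocessing import Pool

if __name__ == "__main__":
    core_cnt = 10
    with Pool(core_cnt) as pool:
        # pool_args: one tuple of arguments per call, e.g. [(name,), ...]
        out = pool.starmap(func, pool_args)
```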
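One commonly reported cause of this pattern (an assumption here; the post does not confirm it) is CPU oversubscription. By default each PyTorch process typically sizes its intra-op thread pool to the machine's core count, so 10 processes can run far more busy threads than there are cores. A minimal sketch that caps the threads per worker with a pool initializer; `threads_per_worker`, `init_worker` and `run_parallel` are hypothetical helper names:

```python
import os
from multiprocessing import Pool  # torch.multiprocessing.Pool is a drop-in alternative


def threads_per_worker(n_workers, total_cores=None):
    """Split the CPU cores evenly across worker processes (at least 1 each)."""
    total = total_cores if total_cores is not None else (os.cpu_count() or 1)
    return max(1, total // n_workers)


def init_worker(n_threads):
    """Runs once in every worker process before it picks up any task."""
    import torch  # imported here so threads_per_worker has no torch dependency

    # Cap PyTorch's intra-op thread pool for this worker process only.
    torch.set_num_threads(n_threads)


def run_parallel(func, pool_args, n_workers=10):
    """starmap func over pool_args, limiting each worker's torch threads."""
    n_threads = threads_per_worker(n_workers)
    with Pool(n_workers, initializer=init_worker, initargs=(n_threads,)) as pool:
        return pool.starmap(func, pool_args)
```

With `func` and `pool_args` as above, this would be called as `out = run_parallel(func, pool_args, n_workers=10)` inside an `if __name__ == "__main__":` guard. Setting the `OMP_NUM_THREADS` environment variable before starting Python is a coarser way to get a similar cap.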