I have an inference function func that takes 9 sec on average to run. But when I try to parallelize it with multiprocessing (even using torch.multiprocessing), each inference takes 20 sec on average. Why is that?
For info: func is an inference function that takes a patient_name and runs a torch model in inference mode on that patient’s data.
import numpy as np
import torch

device = torch.device('cpu')

def func(patient_name):
    # Load this patient's input data and the saved model weights onto the CPU
    data = np.load(my_dict[patient_name]['data_path'])
    model_state = torch.load(my_dict[patient_name]['model_state_path'], map_location='cpu')
    # Rebuild the network from its hyperparameters, then restore the weights
    model = my_net(my_dict[patient_name]['HPs'])
    model = model.to(device)
    model.load_state_dict(model_state)
    model.eval()
    result = model(torch.FloatTensor(data).to(device))
    return result
from torch.multiprocessing import Pool

core_cnt = 10
pool = Pool(core_cnt)
# starmap unpacks each element of pool_args as the positional arguments of func
out = pool.starmap(func, pool_args)
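
One thing that may be worth checking here is CPU thread oversubscription. PyTorch can use several intra-op threads per process, so 10 workers that each start their own set of threads may end up competing for the same cores. The sketch below is only a guess at the cause, not a confirmed diagnosis: it caps the thread count in each worker with a pool initializer. The names threads_per_worker, init_worker, toy_inference and run_demo are hypothetical, toy_inference is a stand-in workload rather than the real func, and the standard-library Pool is used in place of the torch.multiprocessing one for simplicity.

```python
import os
from multiprocessing import Pool


def threads_per_worker(n_workers, n_cores=None):
    """Split the available cores evenly across workers (at least 1 thread each)."""
    if n_cores is None:
        n_cores = os.cpu_count() or 1
    return max(1, n_cores // n_workers)


def init_worker(n_threads):
    """Pool initializer: runs once in every worker process."""
    import torch  # assumes PyTorch is installed in the worker's environment

    # Cap intra-op parallelism so the workers do not fight over the same cores.
    torch.set_num_threads(n_threads)


def toy_inference(size):
    """Hypothetical stand-in for func: a CPU-bound matmul with autograd disabled."""
    import torch

    with torch.no_grad():
        x = torch.rand(size, size)
        return float((x @ x).sum())


def run_demo(n_workers=4, size=256):
    """Run one toy inference per worker with the per-worker thread cap applied."""
    n_threads = threads_per_worker(n_workers)
    with Pool(n_workers, initializer=init_worker, initargs=(n_threads,)) as pool:
        return pool.map(toy_inference, [size] * n_workers)
```

Calling run_demo() under an if __name__ == "__main__": guard and timing it with and without the initializer would show whether the thread cap makes a difference on a given machine. Wrapping the forward pass in torch.no_grad(), as the toy function does, is a separate assumption that avoids building the autograd graph during inference.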