I am using simple BLSTM architecture with 3 layers and I am trying to use multiprocessing.Pool to make inference faster. But it is not working. Below is the snippet of the code. Please suggest solution.
def do_enroll_step(batch_index, batch_data, model):
model.eval()
batch_feature, _, _, batch_spkid = batch_data
with torch.no_grad():
batch_feature = batch_feature.contiguous()
hidden = model.init_hidden(len(batch_feature))
batch_output, _ = model(batch_feature, hidden)
with h5py.File(enroll_result_file, 'a') as f:
f.create_dataset(batch_spkid[0], data = batch_output)
def do_enroll():
starttime = time.time()
num_processes = 5
pool = mp.Pool(processes = num_processes)
pool.starmap_async(do_enroll_step, [(batch_index, batch_data, model) for batch_index, batch_data in enumerate(enroll_loader)])
pool.close()
pool.join()
print('That took {} seconds'.format(time.time() - starttime))