Multiprocessing Pool() hangs

I used a Python multiprocessing Pool with imap() in my Dataset's __init__() to speed up featurizing my input, but my code hangs when initializing the Pool(). Any idea? Here is some pseudocode.

import torch
from multiprocessing import Pool

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, inputs):
        # imap takes the function first, then the iterable
        with Pool(5) as p:
            self.features = list(p.imap(featurization_func, inputs))

P.S.: MyDataset is then loaded with torch.utils.data.DataLoader with num_workers set to 4.

Help appreciated! Thanks in advance~
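
A sketch of a possible workaround, assuming the hang comes from fork-starting the Pool in a process that already has torch's internal threads running: create the Pool from a "spawn" context, and keep featurization_func at module top level so it can be pickled. (featurization_func here is just a placeholder for whatever the question's featurization does.)

import torch
import multiprocessing as mp

def featurization_func(x):
    # placeholder for the featurization from the question
    return x

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, inputs):
        # "spawn" starts fresh interpreter processes instead of forking,
        # so workers don't inherit locks held by the parent's threads
        ctx = mp.get_context("spawn")
        with ctx.Pool(5) as p:
            self.features = list(p.imap(featurization_func, inputs))

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]

Note that with "spawn", anything handed to the Pool (including featurization_func) has to be picklable and importable from the top level of the module.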

Hmm, it’s been well over a year and no answers. I have the same issue now: I’m taking a model and trying to run inference, and it just hangs and never returns.

import torch
import glob2
import json
import random
import re
import numpy as np
from time import perf_counter
from transformers import AutoModel, AutoTokenizer, BertweetTokenizer
from multiprocessing import Pool

model = AutoModel.from_pretrained("vinai/bertweet-base")
model.share_memory()
# For transformers v4.x+:
# tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False)
# BertweetTokenizer has no fast variant, so use_fast is dropped here
tokenizer = BertweetTokenizer.from_pretrained("vinai/bertweet-base")

def inference(path):
    with open(path, 'r') as f:
        user_features = []
        filename = path.split('/')[-1]
        matches = re.search(r"Tweets_([a-zA-Z0-9_]+)",filename)
        tweets = f.readlines()
        max_tweets = max(100, len(tweets))  # tweets available for this user, floored at 100
        tweets = random.sample(tweets, min(100, len(tweets)))  # cap the sample at 100 tweets
        num_tweets = len(tweets)
        print(filename,num_tweets)
        with torch.no_grad():
            for tweet in tweets:
                print('.',end='')
                json_line = json.loads(tweet)
                user_id = json_line['user_id']
                username = json_line['username']
                tweet = json_line['tweet']
                try:
                    normalizedTweet = tokenizer.normalizeTweet(tweet)
                    input_ids = torch.tensor([tokenizer.encode(normalizedTweet)])
                    if input_ids.shape[1] > 128:  # input_ids is [1, seq_len]; check the sequence length
                        raise ValueError("too many tokens")
                    features = model(input_ids)['pooler_output']
                    user_features.append(features.detach().numpy().squeeze())
                except ValueError as e:
                    print(filename, tweet, e)
                except IndexError as e:
                    print(filename, e, tweet)
                except KeyError as e:
                    print(filename, e, tweet)
            output = {
                'username': matches.groups()[0],
                'features': {'type':1,'values':np.array(user_features).mean(axis=0).tolist()},
                'weight': len(user_features),
                'max_tweets': max(len(user_features),max_tweets)
            }
    return output

if __name__=='__main__':
    start = perf_counter()
    files = glob2.glob('Tweets/*.json')
    sampled_files = random.sample(files,16)
    print(f"{len(sampled_files)} being processed...")
    with Pool(4) as p:
        # map() blocks until every file is processed, and the with-block
        # handles cleanup, so explicit close()/join() calls aren't needed
        result = p.map(inference, sampled_files)
    end = perf_counter()
    m, s = divmod(end-start, 60)
    print(f'time {int(m)} minutes {s:.1f} seconds')
    print(result)

When I take out all the statements involving the torch model, the Pool works, so it’s definitely something related to the model (and/or tokenizer).
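
In case it helps whoever hits this next: a common cause is that fork-started workers inherit the parent's already-initialized model, along with whatever thread and lock state it holds at fork time. A sketch of one way around it, assuming inference is refactored to read the per-worker _model/_tokenizer globals below instead of the module-level ones (inference and the file list come from the script above):

import glob2
import random
import torch
import multiprocessing as mp
from transformers import AutoModel, BertweetTokenizer

_model = None
_tokenizer = None

def init_worker():
    # runs once in each pool process, so every worker loads its own copy
    global _model, _tokenizer
    torch.set_num_threads(1)  # keep 4 workers from oversubscribing CPU threads
    _model = AutoModel.from_pretrained("vinai/bertweet-base")
    _tokenizer = BertweetTokenizer.from_pretrained("vinai/bertweet-base")

if __name__ == '__main__':
    files = glob2.glob('Tweets/*.json')
    sampled_files = random.sample(files, 16)
    # "spawn" avoids forking a parent that has already initialized torch
    ctx = mp.get_context("spawn")
    with ctx.Pool(4, initializer=init_worker) as p:
        result = p.map(inference, sampled_files)

Loading the model in an initializer also makes model.share_memory() unnecessary, since nothing is shared across processes anymore.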