Performance collapse under high-concurrency CPU inference (thread oversubscription)

When serving multiple concurrent CPU inference requests (e.g., via threading), PyTorch exhibits a severe performance collapse under moderate-to-high concurrency, even when the system has sufficient vCPUs.

For example, on a 16 vCPU Docker container:

  • A single concurrent user achieves ~40 RPS.

  • But with 6 concurrent users, throughput drops drastically to ~3 RPS, even though the hardware can sustain far more (up to 80 RPS with torch.set_num_threads(2) under the same load).

The same behavior is observed across different models and is especially severe on modern CPUs such as Intel Sapphire Rapids. The degradation occurs consistently whenever torch.set_num_threads is not set and concurrency is ≥2, regardless of the total available CPU resources.
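The mitigation above (torch.set_num_threads(2)) can be generalized by sizing the intra-op pool from the expected concurrency. This is a sketch under the assumption that num_concurrent_users is known at startup; the value is application-specific, not something PyTorch provides:

```python
import os

import torch

# Cap PyTorch's intra-op thread pool before any inference runs.
# With N concurrent request threads on C vCPUs, a budget of roughly
# C // N intra-op threads per request avoids oversubscribing the cores.
num_concurrent_users = 6  # assumed, application-specific value
threads_per_request = max(1, os.cpu_count() // num_concurrent_users)
torch.set_num_threads(threads_per_request)

print(torch.get_num_threads())  # effective intra-op thread count
```

On the 16-vCPU container from the example this yields 2 intra-op threads per request, matching the setting that reached ~80 RPS.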

Minimal Reproduction Script:

import threading
import time

import torch
from transformers import AutoTokenizer

model_path = "path/to/my/RoBERTa Model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
pt_model = torch.jit.load(model_path + '/model.pt')
pt_model.eval()

text_50_tokens = "xxx xxx xxx"
num_of_requests = 200
num_of_users = 6

def simulate_user(user_id, num_requests):
    for i in range(num_requests):
        encoded = tokenizer(text_50_tokens, padding=True, truncation=True, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        with torch.no_grad():  # inference only; skip building the autograd graph
            embeddings = pt_model(input_ids, attention_mask).cpu().numpy()
        print(f"User {user_id} finished request {i}.")
    print(f"User {user_id} finished all {num_requests} requests!")

start_time = time.time()

threads = []
for user_id in range(num_of_users):
    t = threading.Thread(target=simulate_user, args=(user_id, num_of_requests))
    threads.append(t)
    t.start()

for t in threads:
    t.join()

end_time = time.time()
total_time = end_time - start_time
total_requests = num_of_users * num_of_requests

print(f"\nTotal requests: {total_requests}")
print(f"Total time: {round(total_time, 2)} seconds")
print(f"Average RPS: {round(total_requests / total_time, 2)} per second")

The optimal torch.set_num_threads() value varies across models, hardware (e.g., Ice Lake vs. Sapphire Rapids), and workloads (e.g., 1 vs. 6 concurrent users), making it difficult to find a single setting that avoids thread oversubscription while maintaining high performance. Is there a mechanism in PyTorch to dynamically adjust the number of threads so that inference throughput stays consistently high?
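One workaround sketch I have considered (not a built-in PyTorch feature): instead of shrinking the intra-op pool, bound how many inferences run simultaneously so that (active requests) × (intra-op threads) stays within the vCPU budget. The names max_parallel_inferences and run_inference are hypothetical:

```python
import os
import threading

# Keep the intra-op pool at a fixed size and gate concurrency instead,
# so active_requests * intra_op_threads never exceeds the vCPU count.
intra_op_threads = 2  # assumed per-request intra-op budget
max_parallel_inferences = max(1, os.cpu_count() // intra_op_threads)
inference_gate = threading.Semaphore(max_parallel_inferences)

def run_inference(model, *inputs):
    # Blocks when the CPU budget is fully committed, releasing
    # the slot as soon as this inference finishes.
    with inference_gate:
        return model(*inputs)
```

Requests beyond the budget queue briefly at the semaphore rather than thrashing the scheduler, which trades a little latency for stable throughput.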