Torch gets slower when upgrading to a newer version

Hi, just following up on this. I have produced a minimal example that shows a clear regression in text-encoding speed with an SBERT model. The example brackets each timed section with torch.cuda.synchronize() and removes downloading and random text generation (it encodes hard-coded sentences instead), showing that the encoding itself really slows down.
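
For context, the timing pattern used around every section in the script below is the standard synchronize-before-and-after approach, needed because CUDA kernel launches are asynchronous (a minimal sketch; the matmul is just a stand-in workload):

import time

import torch

x = torch.randn(1024, 1024, device="cuda")

torch.cuda.synchronize()  # finish any pending GPU work before starting the clock
start = time.time()
y = x @ x  # the GPU workload being timed (stand-in for a model forward pass)
torch.cuda.synchronize()  # wait for the kernels launched above to complete
print(f"elapsed: {time.time() - start:.6f} s")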

import argparse
import threading
import time
from typing import List, Tuple, Union

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

print("torch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)

DEVICE = "cuda"
device = torch.device(DEVICE)

# Configurable parameters
parser = argparse.ArgumentParser(description='PyTorch performance test (RPS) script')
parser.add_argument('--Threads', help='The number of threads', default=5, type=int)
parser.add_argument('--Requests', help='The number of requests', default=100, type=int)
parser.add_argument('--Model', help='The model', default="sentence-transformers/all-MiniLM-L6-v2", type=str)

args = parser.parse_args()
NUM_THREADS = args.Threads
NUM_REQUESTS = args.Requests
TARGET_MODEL = args.Model
test_text = ['A herd of horses standing on top of a lush green field.',
             'A bird sitting on top of a wooden framed mirror.']

# Loading model
tokenizer = AutoTokenizer.from_pretrained(TARGET_MODEL)
model = AutoModel.from_pretrained(TARGET_MODEL).to(device)
model.eval()  # inference only: make sure dropout is disabled


def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # clamp avoids division by zero on all-zero masks
    return sum_embeddings / sum_mask


def inference_time_on_text(text: Union[str, List[str]]) -> dict:
    times = {}

    # Finish any pending GPU work before starting the clocks
    torch.cuda.synchronize()
    start_total = time.time()
    start = time.time()

    # Tokenize sentences
    encoded_input = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors='pt').to(device)

    torch.cuda.synchronize()
    elapsed_time = time.time() - start
    times['tokenize'] = elapsed_time

    torch.cuda.synchronize()
    start = time.time()

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    torch.cuda.synchronize()
    elapsed_time = time.time() - start
    times['embeddings'] = elapsed_time

    torch.cuda.synchronize()
    start = time.time()
    # Perform pooling. In this case, mean pooling
    _ = mean_pooling(model_output, encoded_input['attention_mask'])

    torch.cuda.synchronize()

    elapsed_time = time.time() - start
    times['pooling'] = elapsed_time

    elapsed_time = time.time() - start_total
    times['total'] = elapsed_time

    return times


class RequestThread(threading.Thread):
    def __init__(self, text):
        super().__init__()
        self.text = text
        self.tokenize_times = []
        self.embeddings_times = []
        self.pooling_times = []
        self.total_times = []

    def run(self):
        for _ in range(NUM_REQUESTS):
            try:
                times = inference_time_on_text(self.text)
                self.tokenize_times.append(times['tokenize'])
                self.embeddings_times.append(times['embeddings'])
                self.pooling_times.append(times['pooling'])
                self.total_times.append(times['total'])
            except Exception as e:
                print(f"Error ({e})")


def warm_up_calls():
    _ = inference_time_on_text("hello world")


def data_processed(time_list: List[float]) -> Tuple[int, float, float, float, float]:
    sample_size = len(time_list)

    arr = np.asarray(time_list)
    mean = np.mean(arr)
    p50 = np.percentile(arr, 50)
    p90 = np.percentile(arr, 90)
    p99 = np.percentile(arr, 99)
    return sample_size, mean, p50, p90, p99


def multi_threaded_test():
    # Create the threads
    threads: List[RequestThread] = list()
    for i in range(NUM_THREADS):
        threads.append(RequestThread(test_text))

    print('Starting threads')
    torch.cuda.synchronize()
    start = time.time()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    torch.cuda.synchronize()
    end = time.time()
    elapsed = end - start

    total_requests = sum(len(t.total_times) for t in threads)
    latencies = [l for t in threads for l in t.total_times]

    sample_size, mean, p50, p90, p99 = data_processed(latencies)
    print("-----------------------------------------")
    print(f"RESULTS FOR SBERT {type}")
    print(f"TORCH: {torch.__version__}, CUDA: {torch.version.cuda}")
    print(f"MODEL: {TARGET_MODEL}, DEVICE: {DEVICE}")
    print(f"NUM_REQUESTS: {NUM_REQUESTS}, NUM_THREADS: {NUM_THREADS}")
    print("Number of queries inferenced:", sample_size)
    print("---RPS---")
    print(total_requests / elapsed)
    print("---LATENCY---")
    print("Total time taken (s):", elapsed)
    print("Mean (s):", mean)
    print("p50 (s):", p50)
    print("p90 (s):", p90)
    print("p99 (s):", p99)

    # Tokenize
    latencies = [l for t in threads for l in t.tokenize_times]

    sample_size, mean, p50, p90, p99 = data_processed(latencies)
    print("-----------------------------------------")
    print("---Tokenize Time---")
    print("Mean (s):", mean)
    print("p50 (s):", p50)
    print("p90 (s):", p90)
    print("p99 (s):", p99)

    # Embeddings
    latencies = [l for t in threads for l in t.embeddings_times]

    sample_size, mean, p50, p90, p99 = data_processed(latencies)
    print("-----------------------------------------")
    print("---Embedding Time---")
    print("Mean (s):", mean)
    print("p50 (s):", p50)
    print("p90 (s):", p90)
    print("p99 (s):", p99)

    # Pooling
    latencies = [l for t in threads for l in t.pooling_times]

    sample_size, mean, p50, p90, p99 = data_processed(latencies)
    print("-----------------------------------------")
    print("---Pooling Time---")
    print("Mean (s):", mean)
    print("p50 (s):", p50)
    print("p90 (s):", p90)
    print("p99 (s):", p99)

    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))


def main():
    warm_up_calls()

    multi_threaded_test()


if __name__ == "__main__":
    main()
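
To run it (the filename repro.py is my choice; the flags are the ones defined via argparse above):

python repro.py --Threads 5 --Requests 100 --Model sentence-transformers/all-MiniLM-L6-v2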

Side-by-side comparison of Torch 2.1.0 vs Torch 1.12.1:

Observe that Torch 2.1.0 is slower: 152 RPS vs. 186 RPS on Torch 1.12.1.
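
For reference, the two environments can be pinned roughly like this (a sketch: depending on your CUDA setup you may need the matching CUDA-specific wheel index instead of the PyPI default):

pip install torch==1.12.1 transformers numpy   # baseline
pip install torch==2.1.0 transformers numpy    # regressed version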
