I was trying to compare torch.compile against other optimization methods, and running it on CPU gives me behavior I did not expect. I was attempting to compile a Hugging Face Transformers model for text classification. Here are my machine specs:
# platform and psutil
Machine type: x86_64
Processor type:
Platform type: Linux-4.19.0-24-cloud-amd64-x86_64-with-glibc2.28
Operating system: Linux
Operating system release: 4.19.0-24-cloud-amd64
Operating system version: #1 SMP Debian 4.19.282-1 (2023-04-29)
Number of physical cores: 4
Number of logical cores: 8
# cpuinfo
python_version: 3.10.11.final.0 (64 bit)
cpuinfo_version: [9, 0, 0]
cpuinfo_version_string: 9.0.0
arch: X86_64
bits: 64
count: 8
arch_string_raw: x86_64
vendor_id_raw: GenuineIntel
brand_raw: Intel(R) Xeon(R) CPU @ 2.30GHz
hz_advertised_friendly: 2.3000 GHz
hz_actual_friendly: 2.3000 GHz
hz_advertised: [2300000000, 0]
hz_actual: [2299998000, 0]
model: 63
family: 6
l3_cache_size: 47185920
l2_cache_size: 262144
l1_data_cache_size: 32768
l1_instruction_cache_size: 32768
l2_cache_line_size: 256
l2_cache_associativity: 6
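For reference, this is roughly how I collected the specs above. It's a small sketch using platform, psutil, and py-cpuinfo, so the exact fields may differ slightly from what your machine reports.
import platform
import psutil
import cpuinfo  # py-cpuinfo

# platform and psutil
print("Machine type:", platform.machine())
print("Processor type:", platform.processor())
print("Platform type:", platform.platform())
print("Operating system:", platform.system())
print("Operating system release:", platform.release())
print("Operating system version:", platform.version())
print("Number of physical cores:", psutil.cpu_count(logical=False))
print("Number of logical cores:", psutil.cpu_count(logical=True))

# cpuinfo
for key, value in cpuinfo.get_cpu_info().items():
    print(f"{key}: {value}")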
When I follow this tutorial and time the same input over and over with timeit, as the tutorial does, I see roughly a 40% speedup, which I guess is expected.
Where things get weird is when I run the compiled model on a set of unique texts (rather than the same input over and over). The model then starts to take an extremely long time, upwards of 10 s per sample. Can someone help me understand what is going on here? Here is the code I am testing with:
import time
import timeit
import random

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
NUM_ITERS = 50
### FOLLOWING TUTORIAL ###
# load base model
model_id = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForSequenceClassification.from_pretrained(model_id)
# create inputs
text = "This is just a sample text I am using as a test."
# inputs = tokenizer(text, padding="max_length", return_tensors="pt")
inputs = tokenizer(text, return_tensors="pt")
# compile the model
print("compiling the model...")
comp_model = torch.compile(base_model)
with torch.no_grad():
    comp_model(**inputs)
print("running timeit on eager mode...")
with torch.no_grad():
    # warmup
    for _ in tqdm(range(10)):
        base_model(**inputs)
    eager_t = timeit.timeit("base_model(**inputs)", number=NUM_ITERS, globals=globals())
print("running timeit on compiled mode...")
with torch.no_grad():
    # warmup
    for _ in tqdm(range(10)):
        comp_model(**inputs)
    inductor_t = timeit.timeit("comp_model(**inputs)", number=NUM_ITERS, globals=globals())
# now run on a set of unique texts
# (TEXTS is my list of unique input strings, defined elsewhere and omitted here)
sample_size = 10
sample = random.sample(TEXTS, sample_size)
print("running inference on base model...")
with torch.no_grad():
    # warmup
    for _ in tqdm(range(10)):
        base_model(**inputs)

    base_total_time = 0
    for text in tqdm(sample):
        new_inputs = tokenizer(text, return_tensors="pt")
        start = time.time()
        base_model(**new_inputs)
        base_total_time += time.time() - start
print("running inference on compiled mode...")
with torch.no_grad():
    # warmup
    for _ in tqdm(range(10)):
        comp_model(**inputs)

    comp_total_time = 0
    for text in tqdm(sample):
        new_inputs = tokenizer(text, return_tensors="pt")
        start = time.time()
        comp_model(**new_inputs)
        comp_total_time += time.time() - start
print(f"eager repeat inputs: {eager_t * 1000 / NUM_ITERS} ms/iter")
print(f"inductor repeat inputs: {inductor_t * 1000 / NUM_ITERS} ms/iter")
print(f"speed up ratio: {eager_t / inductor_t}")
print(f"eager unique inputs: {(base_total_time / sample_size) * 1000} ms/iter")
print(f"inductor unique inputs: {(comp_total_time / sample_size) * 1000} ms/iter")
# compiling the model...
# running timeit on eager mode...
# 100%|██████████| 10/10 [00:01<00:00, 9.14it/s]
# running timeit on compiled mode...
# 100%|██████████| 10/10 [00:00<00:00, 36.74it/s]
# running inference on base model...
# 100%|██████████| 10/10 [00:00<00:00, 28.27it/s]
# 100%|██████████| 10/10 [00:00<00:00, 15.81it/s]
# running inference on compiled mode...
# 100%|██████████| 10/10 [00:00<00:00, 38.31it/s]
# 100%|██████████| 10/10 [02:17<00:00, 13.78s/it]
# eager repeat inputs: 35.97641162000059 ms/iter
# inductor repeat inputs: 25.08608974000026 ms/iter
# speed up ratio: 1.4341179511382955
# eager unique inputs: 61.79332733154297 ms/iter
# inductor unique inputs: 13781.427097320557 ms/iter
EDIT: my VM also kept crashing, which made me think it was running out of memory. I watched htop during the execution loop: on each iteration the memory jumped by roughly 1 GB. Is it compiling a new model on each iteration? Why does memory keep growing?
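If it helps, here is roughly what I was planning to try next to confirm the recompilation theory. This is just a sketch; I am not certain these are the right knobs (in particular the logging/config options), so treat the specifics as assumptions.
# Things I plan to try to confirm/avoid recompilation on every new sequence length.

# 1) Pad/truncate every input to one fixed length so the compiled model
#    always sees the same shape and (I assume) should not recompile.
fixed_inputs = tokenizer(
    text,
    padding="max_length",
    max_length=128,  # arbitrary fixed length I picked for the test
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    comp_model(**fixed_inputs)

# 2) Or compile with dynamic shapes up front instead of the default.
dyn_model = torch.compile(base_model, dynamic=True)

# 3) Check how many graph recompiles dynamo allows before giving up;
#    I believe torch._dynamo.config.cache_size_limit is the relevant knob,
#    and TORCH_LOGS="recompiles" should print each recompilation, but I
#    have not verified either on this PyTorch version.
import torch._dynamo as dynamo
print("cache_size_limit:", dynamo.config.cache_size_limit)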