What is the best way to use torch.compile
for generative models (i.e. invoking model.generate)?
Assuming the following code:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype=torch.bfloat16
)
input_str = '''def print_prime(n):
"""
Print all primes between 1 and n
"""'''
inputs = tokenizer(input_str, return_tensors="pt", return_attention_mask=False, max_length=128)
results = model.generate(**inputs)
decoded_output = tokenizer.decode(results[0])
At what point does one apply torch.compile? I am using my own backend, and I noticed the following:
model = torch.compile(model, backend=my_backend...)
outputs = model.generate(inputs)
doesn’t actually run on my backend.
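For context, the custom backend I'm testing with is essentially a pass-through; a minimal stand-in looks roughly like the sketch below (the name my_backend and the print are just illustrative):

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # A torch.compile backend receives an FX GraphModule plus example inputs
    # and returns a callable; here we just log and fall back to eager execution.
    print(f"my_backend called on a graph with {len(list(gm.graph.nodes))} nodes")
    return gm.forward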
Either of the following does work, though:
model.forward = torch.compile(model.forward, backend=my_backend...)
outputs = model.generate(inputs)
or
fnc = torch.compile(model.generate, backend=my_backend...)
outputs = fnc(inputs)
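To compare the last two options, I've been wrapping the backend in a small counter so I can see how many graphs actually get compiled with each approach; this is a rough sketch, and counting_backend is just a helper name I made up:

def counting_backend(inner_backend):
    # Wrap a backend so every graph compilation increments a counter.
    calls = {"n": 0}
    def backend(gm, example_inputs):
        calls["n"] += 1
        return inner_backend(gm, example_inputs)
    backend.calls = calls
    return backend

counted = counting_backend(my_backend)
model.forward = torch.compile(model.forward, backend=counted)
outputs = model.generate(**inputs, max_new_tokens=32)
print("graphs compiled:", counted.calls["n"])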
My question is: what is the best practice for using torch.compile with generative models, and what are the tradeoffs between these ways of using it?