Torch.compile usage for generative models

What is the best way to use torch.compile for generative models (i.e. invoking model.generate)?

Assuming the following code:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype=torch.bfloat16
)
input_str = '''def print_prime(n):
            """
            Print all primes between 1 and n
            """'''
inputs = tokenizer(input_str, return_tensors="pt", return_attention_mask=False, max_length=128)
results = model.generate(**inputs)
decoded_output = tokenizer.decode(results[0])

Where does one apply torch.compile? I am using my own backend, and I noticed that the following:

model = torch.compile(model, backend=my_backend...)
outputs = model.generate(inputs)

doesn’t actually run on my backend.
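For reference, a quick way to verify which calls actually reach a backend is a minimal counting backend (a rough sketch, not part of any library; max_new_tokens is just illustrative) that falls back to eager execution:

import torch

# Hypothetical debug backend: count how often torch.compile hands us a graph,
# then run the captured graph eagerly.
compile_calls = 0

def counting_backend(gm, example_inputs):
    global compile_calls
    compile_calls += 1
    return gm.forward  # execute the traced graph as-is

compiled_model = torch.compile(model, backend=counting_backend)
compiled_model.generate(**inputs, max_new_tokens=8)
print("backend invoked", compile_calls, "times")  # 0 means generate bypassed the backend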
I could do either of the following though:

model.forward = torch.compile(model.forward, backend=my_backend...)
outputs = model.generate(inputs)

or

fnc = torch.compile(model.generate, backend=my_backend...)
outputs = fnc(inputs)

My question is: what's the best practice for using torch.compile with generative models? What are the tradeoffs between these ways of using torch.compile?

You could add the decorator to the method you would like to compile:

@torch.compile
def generate(...):
   ...
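For example, a rough sketch reusing the model, inputs, and my_backend names from the snippets above (max_new_tokens is just illustrative):

@torch.compile(backend=my_backend)
def generate(input_ids):
    # tracing starts at this call, so the whole generation call is handed to my_backend
    return model.generate(input_ids, max_new_tokens=64)

outputs = generate(inputs["input_ids"])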

Thanks for the reply!
I noticed that my process gets killed due to an out-of-memory error when running with

model.forward = torch.compile(model.forward, backend=my_backend...)
outputs = model.generate(inputs)

But it can run when I do

fnc = torch.compile(model.generate, backend=my_backend...)
outputs = fnc(inputs)

Is this expected?

I assume your process gets killed by the OS because you are running out of memory on the host?

Yes exactly, so I was wondering if there is a recommended way of compiling generative models, or common effects to be aware of.

You could check whether reducing the number of compile threads via the environment variable TORCHINDUCTOR_COMPILE_THREADS helps avoid the OOM.
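For example (a minimal sketch; Inductor reads this variable when its config is loaded, so it is safest to set it before importing torch, or export it in your shell):

import os

# Limit the number of parallel Inductor compile workers to reduce peak host
# memory during compilation.
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import torch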
