PyTorch 2.0 compiled model does not work with the Hugging Face API model.generate()

I'm working with the following code:

import torch
from transformers import BertTokenizer, BertLMHeadModel
import os
import torch._inductor.config as config

config.debug = True
os.environ["AOT_FX_GRAPHS"] = "1"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", legacy=False)
model = BertLMHeadModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16).to(device)

text = "test test nonsense words"
encoded_input = tokenizer(text, return_tensors="pt").to(device)

# compile the model
compiled_model = torch.compile(model) 

output_length = 32
input_length = len(encoded_input.input_ids[0])
print("input_ids shape : " + str(encoded_input.input_ids.shape))

# huggingface API generate
generate_ids = compiled_model.generate(encoded_input.input_ids, min_new_tokens=output_length, max_new_tokens=output_length)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for outstr in output:
    print(outstr)

I want to use the generate() function with the compiled model, but it appears that the compiled model (an instance of the class OptimizedModule) does not have a generate() function of its own. Instead, calling generate() triggers __getattr__(), and the call is forwarded to the original BertLMHeadModel, so generation runs on the uncompiled model.
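
To illustrate what I think is happening, here is a toy sketch in plain Python. It is not PyTorch's actual code: Model and CompiledWrapper are made-up names, and the wrapper only mimics the attribute forwarding I described above.

```python
class Model:
    """Hypothetical stand-in for the original (uncompiled) model."""

    def forward(self, x):
        return f"eager({x})"

    def __call__(self, x):
        return self.forward(x)

    def generate(self, x):
        # generate() calls the model through self, i.e. the ORIGINAL object
        return self(x)


class CompiledWrapper:
    """Hypothetical stand-in for the wrapper returned by torch.compile."""

    def __init__(self, orig_mod):
        self._orig_mod = orig_mod

    def __call__(self, x):
        # Only direct calls on the wrapper take the "compiled" path
        return f"compiled({x})"

    def __getattr__(self, name):
        # Invoked only when normal lookup fails, e.g. for generate()
        return getattr(self._orig_mod, name)


wrapper = CompiledWrapper(Model())
direct = wrapper("x")              # "compiled(x)"
via_generate = wrapper.generate("x")  # "eager(x)"
```

The method returned by __getattr__ is bound to the original object, so everything it calls through self bypasses the wrapper.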

Is there a method to enable calling generate() with the compiled model?

Try compiling the generate function directly: torch.compile(model.generate)
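
Something along these lines, as a rough sketch rather than a tested recipe for BERT. TinyModel is a made-up stand-in with a generate()-style loop, and backend="eager" is my assumption, used only so the snippet runs without a compiler toolchain; drop it to use the default backend.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model that has a generate()-style method."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def generate(self, x, steps=3):
        # Plain Python loop that repeatedly calls forward via self(...)
        for _ in range(steps):
            x = self(x)
        return x


torch.manual_seed(0)
model = TinyModel().eval()
x = torch.randn(1, 4)

with torch.no_grad():
    # Reference result from the uncompiled method
    expected = model.generate(x)
    # Compile the bound method itself instead of the module
    compiled_generate = torch.compile(model.generate, backend="eager")
    actual = compiled_generate(x)
```

With the code from the question, the equivalent would be compiled_generate = torch.compile(model.generate), then calling compiled_generate(encoded_input.input_ids, ...) with the same arguments as before. I have not measured whether this gives a speedup; generate() contains a lot of Python control flow, so graph breaks are likely.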