# export_nanogpt.py
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch._export import capture_pre_autograd_graph
from torch.export import export
import copy
from executorch.backends.apple.mps.partition.mps_partitioner import MPSPartitioner
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.backend_api import to_backend
from model import GPT
# Load the pretrained GPT-2 weights into the local GPT model definition.
model = GPT.from_pretrained('gpt2')

# Create example inputs. This is used in the export process to provide
# hints on the expected shape of the model input: a single batch of
# `block_size` token ids (dtype long, values are arbitrary here).
example_inputs = (torch.randint(0, 100, (1, model.config.block_size), dtype=torch.long), )

# Trace the model, converting it to a portable intermediate representation.
# The torch.no_grad() call tells PyTorch to exclude training-specific logic.
# Restricting SDPA to the MATH backend keeps attention in decomposed,
# exportable ops rather than fused kernels.
# (Uses the `sdpa_kernel` name imported above instead of re-spelling the
# full `torch.nn.attention` path.)
# NOTE(review): capture_pre_autograd_graph is deprecated in newer PyTorch
# in favor of torch.export.export_for_training — confirm the pinned version.
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    m = capture_pre_autograd_graph(model, example_inputs)
    traced_model = export(m, example_inputs)

# Convert the model into a runnable ExecuTorch program.
# _check_ir_validity=False relaxes Edge IR validation during conversion.
edge_config = EdgeCompileConfig(_check_ir_validity=False)
edge_manager = to_edge(traced_model, compile_config=edge_config)

# Keep an untouched copy of the edge program before delegation —
# presumably because to_backend consumes/mutates the manager; verify
# against how edge_manager is used later in the file.
edge_copy = copy.deepcopy(edge_manager)

# Compile spec asking the MPS backend to run in fp16
# (bytes([True]) == b'\x01', the backend's boolean encoding).
compiler_specs = [CompileSpec("use_fp16", bytes([True]))]

# Delegate every subgraph the MPS partitioner supports to the Apple MPS backend.
et_delegate = edge_copy.to_backend(MPSPartitioner(compiler_specs))