Issues tracing GPT-2 models for mobile deployment

Hey folks,

Problem Description

I am trying to prepare GPT-2-based models for deployment in a mobile application using PyTorch JIT tracing, mobile optimization (optimize_for_mobile), and lite-interpreter saving, as recommended in various PyTorch recipes and examples. Our models make use of past_key_values, so we have adapted those examples to pass that input to the model.

The tracing and export steps appear to work as expected, but calling forward on the resulting model fails with the following error on mobile (and a similar error on desktop):

com.facebook.jni.CppException: XNNPACK Linear not available! Reason: The provided (weight, bias, output_min, output_max) parameters are either invalid individually or their combination is not supported by XNNPACK.

Steps to reproduce

This was run on an x86-64 system using PyTorch 2.3.0. I have also reproduced the same error in our mobile environment (using the PyTorch Android API version 2.1.0, the latest available). Running the following script should reproduce the error, specifically at the call to loaded_model:

```python
import torch
from transformers import AutoModelForCausalLM
from torch.utils.mobile_optimizer import optimize_for_mobile


def trace_and_save(model, example_input, example_past, filename):
    class ModelWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()  # required before assigning submodules
            self.model = model

        def forward(self, x, past):
            # return_dict=False gives a plain tuple (logits, presents) for tracing
            return self.model(x, return_dict=False, past_key_values=past)

    wrapped_model = ModelWrapper(model)

    traced_model = torch.jit.trace(wrapped_model, (example_input, example_past))

    optimized_model = optimize_for_mobile(traced_model)
    optimized_model._save_for_lite_interpreter(filename)


def main():
    with torch.no_grad():
        model = AutoModelForCausalLM.from_pretrained(
            "datificate/gpt2-small-spanish", torchscript=True
        )

        N, C = 1, 10
        dummy_input = torch.zeros((N, C), dtype=torch.long)

        # Empty key/value cache for one layer:
        # (batch, num_heads=12, past_seq_len=0, head_dim=64)
        def make_tensor_pair():
            return (
                torch.zeros(1, 12, 0, 64, dtype=torch.float),
                torch.zeros(1, 12, 0, 64, dtype=torch.float),
            )

        # One (key, value) pair per transformer layer (12 layers)
        dummy_past = tuple(make_tensor_pair() for _ in range(12))

        trace_and_save(model, dummy_input, dummy_past, "gpt2.ptl")

        loaded_model = torch.jit.load("gpt2.ptl")

        # This is the call that raises "XNNPACK Linear not available!"
        logits, new_past = loaded_model(dummy_input, dummy_past)


if __name__ == "__main__":
    main()
```

Any ideas or advice would be very welcome, thanks!

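For anyone who comes across this thread: we found that the problem was caused by combining optimize_for_mobile with torch.set_default_dtype(torch.float64), a call not included in the script above. As far as we can tell, the XNNPACK prepacked linear ops that optimize_for_mobile inserts only accept 32-bit float weights, so float64 weights trigger the "XNNPACK Linear not available" error.

We had set the default dtype to float64 because tracing emitted warnings about a few 'mismatched' elements, i.e. elements of the traced output that fell outside PyTorch's default tolerance when compared with the original model's output. That discrepancy is acceptable for our application, so we simply reverted to the default 32-bit float dtype.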
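A minimal sketch of that fix, using a small toy model as a hypothetical stand-in for the GPT-2 wrapper so it runs without downloading weights. It shows that layers built while float64 is the default get float64 weights, flags them with a hypothetical helper (non_float32_tensors) before calling optimize_for_mobile, then restores the float32 default, casts the model with .float(), and runs the optimized model. The float32-only restriction reflects our reading of XNNPACK's prepacked linear op; treat it as an assumption rather than a documented guarantee.

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile


def non_float32_tensors(module: torch.nn.Module) -> list:
    """Hypothetical helper: (name, dtype) of every floating-point parameter or
    buffer that is not float32. Assumes XNNPACK's prepacked ops, which
    optimize_for_mobile inserts for Linear layers, need float32 weights."""
    tensors = list(module.named_parameters()) + list(module.named_buffers())
    return [
        (name, t.dtype)
        for name, t in tensors
        if t.is_floating_point() and t.dtype != torch.float32
    ]


def make_toy() -> torch.nn.Module:
    # Hypothetical stand-in for the GPT-2 wrapper from the script above.
    return torch.nn.Sequential(
        torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
    )


# Under a float64 global default, newly built layers get float64 weights.
torch.set_default_dtype(torch.float64)
toy = make_toy()
print(non_float32_tensors(toy))  # four float64 entries: 0.weight, 0.bias, 2.weight, 2.bias

# Fix: restore the stock default and cast the existing weights back to float32.
torch.set_default_dtype(torch.float32)
toy = toy.float().eval()
assert non_float32_tensors(toy) == []

example = torch.randn(2, 8)
with torch.no_grad():
    optimized = optimize_for_mobile(torch.jit.trace(toy, example))
    mobile_out = optimized(example)
    eager_out = toy(example)

print(mobile_out.dtype)  # torch.float32
```

The .float() cast matters as much as resetting the default: set_default_dtype only affects tensors created after the call, so a model built while float64 was the default keeps its float64 weights.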

Alternatively, if that extra precision matters, we found that skipping the optimize_for_mobile step also works, though, as you might expect, inference performance on mobile dropped significantly.
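A minimal sketch of that alternative, assuming a hypothetical export_for_mobile helper with an optimize flag and a toy float64 Linear layer in place of GPT-2. With optimize=False the traced model is saved with _save_for_lite_interpreter directly, so the XNNPACK rewrite never happens and the float64 weights are kept when the file is loaded back with torch.jit.load (as in the script above).

```python
import os
import tempfile

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile


def export_for_mobile(module, example_inputs, path, optimize=True):
    """Hypothetical helper: trace `module` and save it for the lite interpreter.

    With optimize=False the XNNPACK rewrite is skipped, so float64 weights are
    kept as-is, at the cost of the speed-up optimize_for_mobile would give.
    """
    module = module.eval()
    with torch.no_grad():
        scripted = torch.jit.trace(module, example_inputs)
    if optimize:
        scripted = optimize_for_mobile(scripted)
    scripted._save_for_lite_interpreter(path)
    return path


# Toy float64 model standing in for one built under set_default_dtype(torch.float64).
toy64 = torch.nn.Linear(8, 4).double()
example = torch.randn(2, 8, dtype=torch.float64)

path = export_for_mobile(
    toy64, (example,), os.path.join(tempfile.mkdtemp(), "toy64.ptl"), optimize=False
)

loaded = torch.jit.load(path)
with torch.no_grad():
    out = loaded(example)
print(out.dtype)  # torch.float64
```

Keeping the flag in one export function makes it easy to compare both variants on the device before settling on precision versus speed.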