How to use torch.jit.trace with a model that returns CausalLMOutputWithPast?

I call torch.jit.trace to trace a Qwen2 model (transformers.models.qwen2.modeling_qwen2) with the following code:

    import torch
    from transformers.models.qwen2 import modeling_qwen2

    huggingface_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # llm_config and model_inputs are built elsewhere in the script
    torch_model = modeling_qwen2.Qwen2ForCausalLM.from_pretrained(
        huggingface_model_name,
        config=llm_config,
        ignore_mismatched_sizes=False,
    )
    ...
    pt_model = torch.jit.trace(torch_model.to("cpu"), model_inputs.input_ids)

This error occurred:

    RuntimeError: Tracer cannot infer type of CausalLMOutputWithPast(loss=None, logits=tensor...)

CausalLMOutputWithPast is not a class that inherits from nn.Module; please refer to:
Model outputs

How can I handle this case?

Thanks to the PyTorch team for the great work!

torch.jit.trace can only record outputs that are plain tensors (or tuples of tensors), so it cannot handle the CausalLMOutputWithPast dataclass. Extract the logits instead: wrap the model in a small function or nn.Module that calls it and returns only the logits tensor, then trace that wrapper with torch.jit.trace.
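
Here is a minimal sketch of that approach, reusing the torch_model and model_inputs from your snippet. The LogitsOnlyWrapper name is just for illustration, not part of transformers:

    import torch

    class LogitsOnlyWrapper(torch.nn.Module):
        """Illustrative wrapper: run the causal LM and return only the logits tensor."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            # The underlying model returns a CausalLMOutputWithPast dataclass;
            # returning .logits hands the tracer a plain tensor it can handle.
            # use_cache=False avoids recording the key/value cache during tracing.
            outputs = self.model(input_ids=input_ids, use_cache=False)
            return outputs.logits

    wrapper = LogitsOnlyWrapper(torch_model.to("cpu").eval())
    pt_model = torch.jit.trace(wrapper, model_inputs.input_ids)

The traced graph runs the same forward pass as before; only the output container changes, so downstream code consumes a logits tensor instead of the ModelOutput object.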