I call torch.jit.trace to trace a Qwen2 model from transformers.models.qwen2.modeling_qwen2 with the following code:
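```python
import torch
from transformers.models.qwen2 import modeling_qwen2

huggingface_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# llm_config and model_inputs are defined elsewhere (not shown)
torch_model = modeling_qwen2.Qwen2ForCausalLM.from_pretrained(
    huggingface_model_name,
    config=llm_config,
    ignore_mismatched_sizes=False,
)
...
# Trace on CPU, using the tokenized input_ids as the example input
pt_model = torch.jit.trace(torch_model.to("cpu"), model_inputs.input_ids)
```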
The call fails with this error:
```
RuntimeError: Tracer cannot infer type of CausalLMOutputWithPast(loss=None, logits=tensor...)
```
CausalLMOutputWithPast is not a subclass of nn.Module; see the Transformers "Model outputs" documentation.
How can I handle this case?
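The error suggests the tracer accepts tensors (and tuples, lists, or dicts of tensors) as outputs, but not an output object like CausalLMOutputWithPast. One workaround I am considering is to wrap the model in a small nn.Module whose forward returns only the logits tensor. The sketch below uses a hypothetical toy model instead of the real Qwen2 checkpoint so it runs without a download; ToyOutput, ToyCausalLM and LogitsOnly are illustrative names, not Transformers APIs, and I assume the toy dataclass output behaves like CausalLMOutputWithPast for the tracer.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class ToyOutput:
    # Hypothetical stand-in for CausalLMOutputWithPast (a non-tensor container)
    loss: Optional[torch.Tensor]
    logits: torch.Tensor


class ToyCausalLM(nn.Module):
    """Hypothetical tiny model whose forward returns an output object."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> ToyOutput:
        logits = self.lm_head(self.embed(input_ids))
        return ToyOutput(loss=None, logits=logits)


class LogitsOnly(nn.Module):
    """Hypothetical wrapper: unpacks the output object so the tracer sees only a tensor."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids).logits


model = ToyCausalLM().eval()
input_ids = torch.randint(0, 16, (1, 5))
# Trace the wrapper, not the model itself, so the traced output is a plain tensor
traced = torch.jit.trace(LogitsOnly(model), input_ids)
print(traced(input_ids).shape)
```

For the real model the wrapper's forward would call torch_model(input_ids).logits; I have not verified this on Qwen2. The Transformers "Export to TorchScript" guide also describes a torchscript=True config flag that makes models return tuples instead of output objects, which may be an alternative.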
Thanks to the PyTorch team for the great work!