Saving an exported program from the transformers library

Hello.

I'm trying to export an LLM from the transformers library with torch.export and save it as a .pt2 file.

from transformers import LlamaForCausalLM, LlamaConfig
import torch
from torch.export import export

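config = LlamaConfig()
config.num_hidden_layers = 1  # a single decoder layer keeps the test model small

m = LlamaForCausalLM(config)

with torch.no_grad():
    # dummy_inputs is a small batch of token ids provided by PreTrainedModel
    ep = export(m, (m.dummy_inputs['input_ids'],))

torch.export.save(ep, 'llama.pt2')  # raises the error below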

However, the call to torch.export.save raised a serialization error:

NotImplementedError: No registered serialization name for <class 'transformers.modeling_outputs.CausalLMOutputWithPast'> found. Please update your _register_pytree_node call with a `serialized_type_name` kwarg.

Is there anything I need to do to save the .pt2 file?