Hello.
I'm trying to export an LLM from the transformers library with torch.export and save it as a .pt2 file.
from transformers import LlamaForCausalLM, LlamaConfig
import torch
from torch.export import export
config = LlamaConfig()
config.num_hidden_layers = 1
m = LlamaForCausalLM(config)
with torch.no_grad():
    ep = export(m, (m.dummy_inputs['input_ids'],))
torch.export.save(ep, 'llama.pt2')
But when torch.export.save was called, it raised a serialization error:
NotImplementedError: No registered serialization name for <class 'transformers.modeling_outputs.CausalLMOutputWithPast'> found. Please update your _register_pytree_node call with a `serialized_type_name` kwarg.
Is there anything I need to do to save the .pt2 file?