Cannot export a bfloat16 (bf16) model to ONNX

Hi,

I have a Hugging Face model trained in bfloat16 (bf16). I tried to load the model in bfloat16 and export it with torch.onnx.export, but got the following error: "RuntimeError: unexpected tensor scalar type". My code and the detailed error are below. Is this because exporting bfloat16 models is not supported?

If I change the model-loading code to pegasus_model = PegasusForConditionalGeneration.from_pretrained(model_path) (that is, without torch_dtype=torch.bfloat16), the export below works. Is there a way to export a model with bfloat16 weights?

I am using torch version 1.10.1+cu113 on an A10G GPU.

# Location of the pretrained Pegasus checkpoint ("xx" looks like a placeholder — substitute the real path).
model_path = "xx"
# Load the weights in bfloat16; per the report above, omitting torch_dtype lets the export succeed.
pegasus_model = PegasusForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16)
# The tokenizer is loaded from the same path as the model.
tokenizer = PegasusTokenizer.from_pretrained(model_path)
def export_encoder(model, args, exported_model_path):
    """Export ``model`` to ONNX at ``exported_model_path``.

    The model is switched to eval mode, then ``model``, ``args`` and
    ``exported_model_path`` are handed to ``torch.onnx._export`` with
    gradient tracking disabled.  The exported graph names its input
    ``input_ids`` and its output ``hidden_states``; axes 0 and 1 of both
    are declared dynamic under the names ``batch`` and ``sequence``.
    """
    # The same two axes are dynamic on the input and on the output.
    batch_and_sequence = {0: 'batch', 1: 'sequence'}
    axes_by_name = {
        'input_ids': dict(batch_and_sequence),
        'hidden_states': dict(batch_and_sequence),
    }
    export_options = dict(
        export_params=True,
        opset_version=12,
        input_names=['input_ids'],
        output_names=['hidden_states'],
        dynamic_axes=axes_by_name,
    )

    model.eval()
    with torch.no_grad():
        torch.onnx._export(model, args, exported_model_path, **export_options)
# Tokenize a short sample sentence; its input_ids are passed to the exporter.
export_text = "This is a great one"
export_input = tokenizer(export_text, return_tensors='pt')
# NOTE(review): `model` and `output_encoder_path` are not defined in this snippet;
# presumably `model` is the `pegasus_model` loaded above and `output_encoder_path`
# is the destination path for the exported file — confirm.
# Only the encoder submodule is passed to the exporter, with input_ids as its argument.
export_encoder(model.model.encoder, export_input['input_ids'], output_encoder_path)

Detailed Error:

~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/__init__.py in _export(*args, **kwargs)
26 def _export(*args, **kwargs):
27 from torch.onnx import utils
---> 28 result = utils._export(*args, **kwargs)
29 return result
30
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, use_external_data_format, onnx_shape_inference)
722
723 graph, params_dict, torch_out = \
--> 724 _model_to_graph(model, args, verbose, input_names,
725 output_names, operator_export_type,
726 example_outputs, val_do_constant_folding,
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
495 params_dict = _get_named_param_dict(graph, params)
496
--> 497 graph = _optimize_graph(graph, operator_export_type,
498 _disable_torch_constant_prop=_disable_torch_constant_prop,
499 fixed_batch_size=fixed_batch_size, params_dict=params_dict,
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
214 dynamic_axes = {} if dynamic_axes is None else dynamic_axes
215 torch._C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
--> 216 graph = torch._C._jit_pass_onnx(graph, operator_export_type)
217 torch._C._jit_pass_lint(graph)
218
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
371 def _run_symbolic_function(*args, **kwargs):
372 from torch.onnx import utils
--> 373 return utils._run_symbolic_function(*args, **kwargs)
374
375
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/utils.py in _run_symbolic_function(g, block, n, inputs, env, operator_export_type)
1030 return None
1031 attrs = {k: n[k] for k in n.attributeNames()}
-> 1032 return symbolic_fn(g, *inputs, **attrs)
1033
1034 elif ns == "prim":
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py in wrapper(g, *args, **kwargs)
170 if len(kwargs) == 1:
171 assert "_outputs" in kwargs
--> 172 return fn(g, *args, **kwargs)
173
174 return wrapper
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/symbolic_opset9.py in embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse)
494 "ONNX does not support not updating the embedding vector at padding_idx during training.")
495
--> 496 return g.op("Gather", weight, indices)
497
498
~/pegasus_exp/lib/python3.9/site-packages/torch/onnx/utils.py in _graph_op(g, opname, *raw_args, **kwargs)
926 if _onnx_shape_inference:
927 from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
--> 928 torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version)
929
930 if outputs == 1:
RuntimeError: unexpected tensor scalar type