Error exporting BERT + classification layer to ONNX

I have the following model:

class BertClassifier(nn.Module):
    """Classifier with a BERT encoder and a single fully connected layer.

    The pooled output of a pretrained ``bert-base-uncased`` encoder is passed
    through dropout, one linear layer and a ReLU.

    Args:
        dropout: Dropout probability applied to the pooled encoder output.
        num_labels: Number of output units of the linear classification layer.
    """

    def __init__(self, dropout: float = 0.5, num_labels: int = 24):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        # The linear layer consumes the encoder's pooled output, so its input
        # width (768) has to match the hidden size of the checkpoint above.
        self.linear = nn.Linear(768, num_labels)
        self.relu = nn.ReLU()
        # Plain attribute initialised to 0; nothing in this class reads it.
        self.best_score = 0

    def forward(self, input_id, mask):
        """Run the encoder and the classification head.

        Args:
            input_id: Token ids, forwarded to the encoder as ``input_ids``.
            mask: Attention mask, forwarded to the encoder as
                ``attention_mask``.

        Returns:
            The ReLU-activated output of the linear layer, so every value is
            non-negative (presumably one row of ``num_labels`` scores per
            input sequence -- confirm against the encoder's output shape).
        """
        # With return_dict=False the encoder returns a tuple; only the second
        # element (the pooled output) is used, the first is discarded.
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        output = self.relu(self.linear(self.dropout(pooled_output)))

        return output

Using these inputs:

    ex_string = "example string"
    # return_tensors="pt" makes the tokenizer return PyTorch tensors, which
    # the .squeeze() call below requires (the default output is plain lists).
    inputs = tokenizer(ex_string,
                       padding='max_length', max_length=512, truncation=True,
                       return_tensors="pt")
    # squeeze(1) only removes dimension 1 when that dimension has size 1.
    input_id = inputs['input_ids'].squeeze(1)
    mask = inputs['attention_mask']

And I export the model to ONNX using:

# Export the model to 'tryout.onnx', using the example tensors as the model
# inputs and enabling constant folding.
torch.onnx.export(
    model,
    (input_id, mask),
    'tryout.onnx',
    export_params=True,
    do_constant_folding=True,
)

Which results in the following stack trace:

/.local/lib/python3.9/site-packages/torch/onnx/ UserWarning: Type cannot be inferred, which might cause exported graph to produce incorrect results.
  warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.")
[W shape_type_inference.cpp:434] Warning: Constant folding in symbolic shape inference fails: index_select(): Index is supposed to be a vector
Exception raised from index_select_out_cpu_ at ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:887 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fc288ab8d62 in /.local/lib/python3.9/site-packages/torch/lib/
frame #1: at::native::index_select_out_cpu_(at::Tensor const&, long, at::Tensor const&, at::Tensor&) + 0x3a9 (0x7fc2cd9e5189 in /.local/lib/python3.9/site-packages/torch/lib/
frame #2: at::native::index_select_cpu_(at::Tensor const&, long, at::Tensor const&) + 0xe6 (0x7fc2cd9e7146 in /.local/lib/python3.9/site-packages/torch/lib/
frame #3: <unknown function> + 0x1d37f12 (0x7fc2ce0def12 in /.local/lib/python3.9/site-packages/torch/lib/
frame #4: at::_ops::index_select::redispatch(c10::DispatchKeySet, at::Tensor const&, long, at::Tensor const&) + 0xb9 (0x7fc2cdc7a099 in /.local/lib/python3.9/site-packages/torch/lib/
frame #5: <unknown function> + 0x3250ac3 (0x7fc2cf5f7ac3 in /.local/lib/python3.9/site-packages/torch/lib/
frame #6: <unknown function> + 0x32510f5 (0x7fc2cf5f80f5 in /.local/lib/python3.9/site-packages/torch/lib/
frame #7: at::_ops::index_select::call(at::Tensor const&, long, at::Tensor const&) + 0x166 (0x7fc2cdcf9ce6 in /.local/lib/python3.9/site-packages/torch/lib/
frame #8: torch::jit::onnx_constant_fold::runTorchBackendForOnnx(torch::jit::Node const*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, int) + 0x1b5f (0x7fc34fd5d6ff in /.local/lib/python3.9/site-packages/torch/lib/
frame #9: <unknown function> + 0xbbdc22 (0x7fc34fda4c22 in /.local/lib/python3.9/site-packages/torch/lib/
frame #10: torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map<std::string, c10::IValue, std::less<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&, int) + 0xa8e (0x7fc34fdaa46e in /.local/lib/python3.9/site-packages/torch/lib/
frame #11: <unknown function> + 0xbc4f74 (0x7fc34fdabf74 in /.local/lib/python3.9/site-packages/torch/lib/
frame #12: <unknown function> + 0xb35730 (0x7fc34fd1c730 in /.local/lib/python3.9/site-packages/torch/lib/
frame #13: <unknown function> + 0x2a5d8b (0x7fc34f48cd8b in /.local/lib/python3.9/site-packages/torch/lib/
frame #14: python3() [0x53a8eb]
<omitting python frames>
frame #17: python3() [0x50f5e9]
frame #20: python3() [0x50f5e9]
frame #23: python3() [0x50f5e9]
frame #26: python3() [0x50f5e9]
frame #29: python3() [0x50f5e9]
frame #32: python3() [0x50f5e9]
frame #35: python3() [0x608ebb]
frame #36: python3() [0x603ea4]
frame #37: python3() [0x60834d]
frame #41: <unknown function> + 0x2dfd0 (0x7fc352298fd0 in /lib/x86_64-linux-gnu/
frame #42: __libc_start_main + 0x7d (0x7fc35229907d in /lib/x86_64-linux-gnu/
 (function ComputeConstantFolding)
Traceback (most recent call last):
  File "/bert_extraction/", line 71, in <module>
    torch.onnx.export(model, (input_id, mask), 'tryout.onnx', export_params=True, do_constant_folding=True)
  File "/.local/lib/python3.9/site-packages/torch/onnx/", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/.local/lib/python3.9/site-packages/torch/onnx/", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/.local/lib/python3.9/site-packages/torch/onnx/", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/.local/lib/python3.9/site-packages/torch/onnx/", line 544, in _model_to_graph
    params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
IndexError: index_select(): Index is supposed to be a vector

I can get the model to export to ONNX when I set the do_constant_folding flag to False, but obviously I don’t want to do that, as I’m trying to optimize inference time.
Can anyone shed some light on the error or what I’m doing wrong?

1 Like