I’m following the English-to-German Transformer translation example here: https://pytorch.org/hub/pytorch_fairseq_translation/
I’d like to export this model to ONNX and run inference with ONNX Runtime. I’ve found a tutorial here: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
But the torch.onnx.export step is failing. I suspect the issue is that I’m not 100% sure of the number or shape of the inputs. I think BERT typically takes 3 inputs, all of the same shape (e.g. (batch_size, 256) or (batch_size, 1024)). Is there some way to probe the Hub model and find out the input and output names and shapes it expects?
Would anyone be able to help me spot my mistake?
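One way to answer the probing question without reading the source is to ask Python for the signature of the model’s forward() method. This is a minimal sketch using only the standard inspect module; describe_forward and FakeTranslationModel are hypothetical names, and the stand-in’s parameter names are an assumption modelled on fairseq’s encoder-decoder forward(), so check them against your installed version. A signature gives argument names and order, not tensor shapes.

```python
import inspect


def describe_forward(module):
    """List (name, kind, has_default) for each parameter of module.forward.

    Works on any object with a ``forward`` method, such as a torch.nn.Module.
    inspect.signature on a bound method already leaves out ``self``.
    """
    sig = inspect.signature(module.forward)
    return [
        (name, param.kind.name, param.default is not inspect.Parameter.empty)
        for name, param in sig.parameters.items()
    ]


# Hypothetical stand-in: the argument names mirror fairseq's encoder-decoder
# forward() as I understand it (assumption; verify on your fairseq version).
class FakeTranslationModel:
    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        return None


for row in describe_forward(FakeTranslationModel()):
    print(row)
```

On the real model you would call describe_forward(en2de.models[0]) instead; for a torch.nn.Module, forward is an ordinary bound method, so inspect.signature works on it. Printing en2de.models[0] additionally lists its submodules.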
import torch
# Load an En-De Transformer model trained on WMT'19 data:
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', tokenizer='moses', bpe='fastbpe')
# The hub interface keeps the underlying model(s) in .models; take the first one
bert_model = en2de.models[0]
# Export the model
batch_size = 1
x = torch.ones((batch_size, 1024), dtype=torch.long)
y = torch.ones((batch_size, 1024), dtype=torch.long)
z = torch.ones((batch_size, 1024), dtype=torch.long)
torch.onnx.export(bert_model, # model being run
(x,y,z), # model input (or a tuple for multiple inputs)
"bert_en2de.onnx", # where to save the model (can be a file or file-like object)
export_params=True) # store the trained parameter weights inside the model file
Traceback (excerpt):
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/jit/__init__.py in forward(self, *args)
358 in_vars + module_state,
359 _create_interpreter_name_lookup_fn(),
--> 360 self._force_outplace,
361 )
362
The error is:
RuntimeError: hasSpecialCase INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022119164/work/torch/csrc/jit/passes/alias_analysis.cpp:300, please report a bug to PyTorch. We don’t have an op for aten::uniform but it isn’t a special case. (analyzeImpl at /opt/conda/conda-bld/pytorch_1579022119164/work/torch/csrc/jit/passes/alias_analysis.cpp:300)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7fa0fd3ce627 in /home/sdp/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: + 0x304553b (0x7fa10063953b in /home/sdp/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/lib/libtorch.so)
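To see which argument each dummy tensor in (x, y, z) actually lands on, the example tuple can be bound to the forward() signature before calling torch.onnx.export; inspect.Signature.bind raises TypeError if the count does not fit. In this sketch map_example_inputs and fake_forward are hypothetical, the parameter names are an assumption modelled on fairseq’s encoder-decoder forward(), and shapes are plain tuples so it runs without torch. If that assumption holds, the second (batch_size, 1024) tensor is bound to src_lengths, which as far as I know is a per-sentence length vector of shape (batch_size,), so the three inputs would not all share one shape.

```python
import inspect


def map_example_inputs(forward_fn, example_args):
    """Bind a tuple of example inputs to forward_fn's parameters.

    Signature.bind raises TypeError if the number of inputs does not fit,
    which is cheaper to discover than a failed torch.onnx.export.
    """
    bound = inspect.signature(forward_fn).bind(*example_args)
    return dict(bound.arguments)


# Hypothetical stand-in; parameter names are an assumption modelled on
# fairseq's encoder-decoder forward(). Shapes are plain tuples, not tensors.
def fake_forward(src_tokens, src_lengths, prev_output_tokens, **kwargs):
    return None


batch_size = 1
example = ((batch_size, 1024), (batch_size, 1024), (batch_size, 1024))
for name, shape in map_example_inputs(fake_forward, example).items():
    print(f"{name:20s} <- shape {shape}")
```

The same call works on a real module’s bound method (e.g. map_example_inputs(bert_model.forward, (x, y, z))), since binding only inspects the signature and never runs the model.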