Hi, everyone!
Issue Summary:
I'm trying to convert the ESM-1b protein transformer model from PyTorch to ONNX. The export fails as soon as I pass an extra argument (repr_layers) to the model's forward() in addition to the token tensor.
Conversion Steps, Code, and Error Message:
Here is my conversion script (named convert_onnx_esm.py):

```python
import os
import torch
import torch.onnx
import argparse
from esm.pretrained import load_model_and_alphabet_local

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--converted-model-path", type=str, required=True)
commandline_args = parser.parse_args()

model, alphabet = load_model_and_alphabet_local(commandline_args.model_path)
batch_converter = alphabet.get_batch_converter()

data = [
    ("protein1", "VLAGG"),
    ("protein2", "KALTARQ"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# An example forward pass would be: model(batch_tokens, repr_layers=[33])
# The conversion works if (batch_tokens, [33]) is changed to just: batch_tokens
with torch.no_grad():
    torch.onnx.export(
        model,
        (batch_tokens, [33]),
        commandline_args.converted_model_path,
        use_external_data_format=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["inputs"],
        output_names=["outputs"],
        dynamic_axes={"inputs": [0, 1]},
    )
```
I run it as follows:

```bash
export MODEL_PATH=/tmp/models/esm/esm1b_t33_650M_UR50S.pt
export CONVERTED_GRAPH_PATH=/tmp/models/onnx_esm/graph.onnx
mkdir -p "$(dirname "$MODEL_PATH")" "$(dirname "$CONVERTED_GRAPH_PATH")"
# This is a 7 GB file
curl https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt --output "$MODEL_PATH"
python convert_onnx_esm.py --model-path "$MODEL_PATH" --converted-model-path "$CONVERTED_GRAPH_PATH"
```
Running the conversion script fails with this error:

```
RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int
```
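The message suggests that the tracer walks the example inputs and accepts only tensors, optionally nested inside tuples, lists, or dicts. In (batch_tokens, [33]) the list itself is accepted, but its element 33 is a plain Python int, which would explain "unsupported type: int". The helper below, find_untraceable, is hypothetical (it is not part of PyTorch); it is a minimal sketch that mimics the rule quoted in the error message to locate the offending leaf.

```python
import torch


def find_untraceable(obj, path="args"):
    """Hypothetical helper: return the paths of leaves that are not tensors.

    It follows the rule quoted in the error message: tensors may be nested
    in tuples, lists, and dicts (strings are tolerated); anything else is
    reported as unsupported.
    """
    if isinstance(obj, (torch.Tensor, str)):
        return []
    if isinstance(obj, (tuple, list)):
        bad = []
        for i, item in enumerate(obj):
            bad.extend(find_untraceable(item, f"{path}[{i}]"))
        return bad
    if isinstance(obj, dict):
        bad = []
        for key, value in obj.items():
            bad.extend(find_untraceable(value, f"{path}[{key!r}]"))
        return bad
    # Any other leaf (int, float, bool, None, ...) is not a traceable input.
    return [f"{path}: {type(obj).__name__}"]


tokens = torch.zeros(2, 9, dtype=torch.long)  # stand-in for batch_tokens
print(find_untraceable((tokens, [33])))  # ['args[1][0]: int']
print(find_untraceable((tokens,)))       # []
```

The second call matches the observation in the script's comment that exporting with batch_tokens alone works.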
Environment:
CUDA 11.2, cuDNN 8.1.1.33, and Python 3.8.5 with these packages:

```
fair-esm==0.3.0
onnx==1.8.1
onnxconverter-common==1.6.0
onnxruntime-gpu==1.7.0
onnxruntime-tools==1.6.0
torch==1.9.0.dev20210318
```
Additional Info:
To no avail, I also tried replacing the arguments (batch_tokens, [33]) with a dictionary format:

```python
(batch_tokens, {"tokens": batch_tokens, "repr_layers": [33]})
```
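The torch.onnx.export documentation describes an args tuple whose last element is a dict as positional arguments followed by named arguments. Read that way, the dictionary above passes tokens twice: once positionally and once by name. The sketch below imitates that split in plain Python; split_export_args and the stand-in forward (same parameters as the ESM forward(), body omitted) are hypothetical illustrations, and the string batch_tokens is only a placeholder for the real tensor.

```python
def split_export_args(args):
    """Hypothetical helper mimicking the documented convention:
    a trailing dict in the args tuple is treated as keyword arguments."""
    if isinstance(args, tuple) and args and isinstance(args[-1], dict):
        return args[:-1], dict(args[-1])
    return (args if isinstance(args, tuple) else (args,)), {}


def forward(tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
    """Stand-in with the same parameters as the ESM forward(); body omitted."""
    return tokens, list(repr_layers)


batch_tokens = "batch_tokens"  # placeholder for the real token tensor

# The dictionary attempt: tokens is supplied positionally AND by name.
pos, kw = split_export_args((batch_tokens, {"tokens": batch_tokens, "repr_layers": [33]}))
try:
    forward(*pos, **kw)
except TypeError as err:
    print(err)  # forward() got multiple values for argument 'tokens'

# Dropping the duplicate key makes the call itself valid, but repr_layers
# is still a list of Python ints rather than a tensor.
pos, kw = split_export_args((batch_tokens, {"repr_layers": [33]}))
print(forward(*pos, **kw))  # ('batch_tokens', [33])
```

Under this reading, removing the duplicate "tokens" key would likely still leave the original "unsupported type: int" problem, because the value of repr_layers is not a tensor.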
Finally, in case it's helpful, the nn.Module.forward() method definition starts like this:

```python
def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
    ...
```
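Judging from this signature, repr_layers and the two boolean flags appear to steer Python control flow inside forward() rather than carry tensor data, so a trace-based export cannot keep them as runtime inputs. A workaround often suggested for this situation is a thin wrapper module that fixes repr_layers as a constant, takes only the token tensor, and returns a tuple of tensors instead of a nested dict. The sketch below assumes the model returns a dict with "logits" and "representations" (keyed by layer index); ToyESM and FixedReprLayers are hypothetical names, and ToyESM is only a tiny stand-in with the same forward() parameters. It uses torch.jit.trace as a proxy, since the TorchScript-based ONNX exporter traces the module in a similar way.

```python
import torch


class ToyESM(torch.nn.Module):
    """Hypothetical stand-in with the same forward() parameters as ESM.

    Assumed output shape: {"logits": ..., "representations": {layer: tensor}}.
    """

    def __init__(self, vocab_size=20, dim=8, num_layers=3):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.layers = torch.nn.ModuleList(torch.nn.Linear(dim, dim) for _ in range(num_layers))
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
        x = self.embed(tokens)
        hidden = {}
        if 0 in repr_layers:
            hidden[0] = x
        for i, layer in enumerate(self.layers, start=1):
            x = torch.relu(layer(x))
            if i in repr_layers:  # plain Python control flow, fixed at trace time
                hidden[i] = x
        return {"logits": self.head(x), "representations": hidden}


class FixedReprLayers(torch.nn.Module):
    """Hypothetical wrapper: tokens is the only input; the layer is a constant."""

    def __init__(self, model, layer):
        super().__init__()
        self.model = model
        self.layer = layer

    def forward(self, tokens):
        out = self.model(tokens, repr_layers=[self.layer])
        # Return a flat tuple of tensors instead of a dict keyed by ints.
        return out["logits"], out["representations"][self.layer]


wrapper = FixedReprLayers(ToyESM().eval(), layer=3).eval()
tokens = torch.randint(0, 20, (2, 9))  # stand-in for batch_tokens

with torch.no_grad():
    traced = torch.jit.trace(wrapper, (tokens,))  # single tensor input traces fine
    logits, rep = traced(tokens)
    ref_logits, ref_rep = wrapper(tokens)
```

In the original script, FixedReprLayers(model, 33) would take the place of model, with args=(batch_tokens,) and one entry in output_names per returned tensor (for example "logits" and "representations"). This sketch does not confirm that the full ESM-1b graph then exports cleanly at opset 11.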
Thank you for any tips or pointers!