ONNX Export with Multiple Arguments: Only tuples, lists and Variables are supported as JIT inputs/outputs

Hi, everyone!

Issue Summary:

I’m trying to convert the ESM-1b protein transformer model from PyTorch to ONNX, and I run into an issue as soon as I try to pass an extra, non-tensor argument to the model’s forward pass.

Conversion Steps, Code, and Error Message:

Here is my conversion script (named convert_onnx_esm.py):

import os
import torch
import torch.onnx
import argparse
from esm.pretrained import load_model_and_alphabet_local


parser = argparse.ArgumentParser()

parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--converted-model-path", type=str, required=True)
commandline_args = parser.parse_args()

model, alphabet = load_model_and_alphabet_local(commandline_args.model_path)
batch_converter = alphabet.get_batch_converter()

data = [
    ("protein1", "VLAGG"),
    ("protein2", "KALTARQ"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

# an example forward pass would be: model(batch_tokens, repr_layers=[33])
# the conversion works if (batch_tokens, [33]) is changed to just: batch_tokens
with torch.no_grad():
    torch.onnx.export(
        model,
        (batch_tokens, [33]),
        commandline_args.converted_model_path,
        use_external_data_format=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["inputs"],
        output_names=["outputs"],
        dynamic_axes={"inputs": [0, 1]},
    )

Which I run as follows:

export MODEL_PATH=/tmp/models/esm/esm1b_t33_650M_UR50S.pt
export CONVERTED_GRAPH_PATH=/tmp/models/onnx_esm/graph.onnx

mkdir -p $(dirname $MODEL_PATH) $(dirname $CONVERTED_GRAPH_PATH)
curl https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt --output $MODEL_PATH  # this is a ~7 GB file

python convert_onnx_esm.py --model-path $MODEL_PATH --converted-model-path $CONVERTED_GRAPH_PATH

And the error I get after running the conversion script:

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: int
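If it helps with the diagnosis: the error seems to come from the tracer flattening the example inputs, which only accepts tensors (and nested tuples/lists of tensors). Here is a toy module I made up purely for illustration (not ESM), which, as far as I can tell, should reproduce the same RuntimeError on this setup:

import torch

class Toy(torch.nn.Module):
    # stand-in for ESM's forward(tokens, repr_layers=[], ...)
    def forward(self, x, layers=[]):
        return x + len(layers)

# The example inputs contain a plain Python int inside a list, which the
# tracer's input flattening rejects with the same "unsupported type: int".
torch.onnx.export(Toy(), (torch.ones(2, 3), [33]), "toy.onnx")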

Environment:

CUDA 11.2, cuDNN 8.1.1.33, and Python 3.8.5 with the following packages:

fair-esm==0.3.0
onnx==1.8.1
onnxconverter-common==1.6.0
onnxruntime-gpu==1.7.0
onnxruntime-tools==1.6.0
torch==1.9.0.dev20210318

Additional Info:

I also tried changing the arguments (batch_tokens, [33]) to the dictionary format, to no avail:

(batch_tokens, {"tokens": batch_tokens, "repr_layers": [33]})

And finally, in case it’s helpful, the nn.Module.forward() method definition starts like this:

def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):

Thank you for any tips or pointers!

Update: I realized that passing the argument as a plain Python list is not valid for export; a tensor is needed, like this:

torch.onnx.export(
    model,
    (batch_tokens, torch.tensor([33])),
    commandline_args.converted_model_path,
    use_external_data_format=True,
    ...
)

Unfortunately, torch.tensor([33]) turns out to be as good as [] as far as the model’s behavior is concerned: the export now goes through, but the model acts as if no repr_layers had been requested (presumably because the integer layer indices inside the model never match the tensor elements during the membership check).
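One workaround I’ve seen for this class of problem (a sketch, not ESM-specific advice) is to wrap the model in a thin nn.Module that hard-codes the non-tensor arguments, so that the export only ever sees tensor inputs and repr_layers stays a plain Python list inside forward. This assumes, as in the fair-esm README, that the model returns a dict with "logits" and "representations" keys:

import torch

class ESMWrapper(torch.nn.Module):
    """Hypothetical wrapper that fixes the extra arguments at export time."""

    def __init__(self, model, repr_layer=33):
        super().__init__()
        self.model = model
        self.repr_layer = repr_layer  # plain Python int, baked into the trace

    def forward(self, tokens):
        out = self.model(tokens, repr_layers=[self.repr_layer])
        # Return plain tensors so the tracer sees supported output types.
        return out["logits"], out["representations"][self.repr_layer]

wrapped = ESMWrapper(model, repr_layer=33)
with torch.no_grad():
    torch.onnx.export(
        wrapped,
        (batch_tokens,),
        commandline_args.converted_model_path,
        use_external_data_format=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["inputs"],
        output_names=["logits", "representations"],
        dynamic_axes={"inputs": [0, 1]},
    )

Since tracing bakes non-tensor arguments into the graph as constants anyway, fixing repr_layers at export time shouldn’t lose any generality: the exported graph is static regardless.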