How do you call a PyTorch model's function in ONNX to get its output instead of the model.forward() function?

TL;DR: How can I use model.whatever_function(input) instead of model.forward(input) with onnxruntime?

I use CLIP to create embeddings for my images and texts like this:

The code is from the official Git repository:

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

import clip
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model, preprocess = clip.load("RN50", device=device) # Load any model
model = model.eval() # Inference Only

img_size = model.visual.input_resolution
dummy_image = torch.randn(10, 3, img_size, img_size).to(device)
image_embedding = model.encode_image(dummy_image).to(device)

dummy_texts = clip.tokenize(["quick brown fox", "lorem ipsum"]).to(device)
model.encode_text(dummy_texts)

It works fine and gives me [Batch, 1024] tensors for both images and texts with the loaded model.

Now I have exported my model to ONNX like this:

model.forward(dummy_image, dummy_texts) # Original CLIP result (1)

torch.onnx.export(model, (dummy_image, dummy_texts), "model.onnx", export_params=True,
  input_names=["IMAGE", "TEXT"],
  output_names=["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
  opset_version=14,
  dynamic_axes={
      "IMAGE": {
          0: "image_batch_size",
      },
      "TEXT": {
          0: "text_batch_size",
      },
      "LOGITS_PER_IMAGE": {
          0: "image_batch_size",
          1: "text_batch_size",
      },
      "LOGITS_PER_TEXT": {
          0: "text_batch_size",
          1: "image_batch_size",
      },
  }
)

The model is saved successfully.
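As far as I understand, torch.onnx.export only traces the module's forward(), which is why this graph returns logits rather than embeddings. A common workaround is to wrap each CLIP method in a thin nn.Module whose forward() calls it, and export each wrapper on its own. The sketch below assumes that approach; ImageEncoder, TextEncoder, export_encoder and ToyCLIP are hypothetical names, and ToyCLIP only stands in for the real model (same method names) so the sketch runs without downloading weights.

```python
import torch
import torch.nn as nn


class ImageEncoder(nn.Module):
    """Hypothetical wrapper: exposes clip_model.encode_image as forward()."""

    def __init__(self, clip_model: nn.Module):
        super().__init__()
        self.clip_model = clip_model  # registered as a submodule, so its weights get exported

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.clip_model.encode_image(image)


class TextEncoder(nn.Module):
    """Hypothetical wrapper: exposes clip_model.encode_text as forward()."""

    def __init__(self, clip_model: nn.Module):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        return self.clip_model.encode_text(text)


def export_encoder(module, dummy_input, path, input_name, output_name, opset=14):
    """Hypothetical helper: export one wrapper with a dynamic batch dimension."""
    module.eval()
    torch.onnx.export(
        module,
        (dummy_input,),
        path,
        export_params=True,
        input_names=[input_name],
        output_names=[output_name],
        opset_version=opset,
        dynamic_axes={input_name: {0: "batch_size"}, output_name: {0: "batch_size"}},
    )


class ToyCLIP(nn.Module):
    """Hypothetical stand-in with CLIP's method names, used only to run the sketch."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.visual = nn.Linear(3 * 2 * 2, dim)
        self.token_embedding = nn.Embedding(100, dim)

    def encode_image(self, image):
        return self.visual(image.flatten(1))  # [B, 3, 2, 2] -> [B, dim]

    def encode_text(self, text):
        return self.token_embedding(text).mean(dim=1)  # [B, tokens] -> [B, dim]


toy = ToyCLIP().eval()
images = torch.randn(4, 3, 2, 2)
tokens = torch.randint(0, 100, (2, 5))
with torch.no_grad():
    image_emb = ImageEncoder(toy)(images)  # same result as toy.encode_image(images)
    text_emb = TextEncoder(toy)(tokens)    # same result as toy.encode_text(tokens)
```

With the real model, a call such as export_encoder(ImageEncoder(model), dummy_image, "image_encoder.onnx", "IMAGE", "IMAGE_EMBEDDING") should give a graph whose single output is the [Batch, 1024] embedding, which onnxruntime can then return via ort.InferenceSession("image_encoder.onnx").run(None, {"IMAGE": dummy_image.cpu().numpy()})[0].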

When I test the model with:

# Now run onnxruntime to verify
import onnxruntime as ort

ort_sess = ort.InferenceSession("model.onnx")
result = ort_sess.run(["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
  {"IMAGE": dummy_image.cpu().numpy(), "TEXT": dummy_texts.cpu().numpy()})  # .cpu() is needed if the tensors live on CUDA

It gives me a list of length 2, one array per output (image logits and text logits), and result[0] has shape [Batch, 2] instead of the [Batch, 1024] embeddings I want.
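If I read CLIP's model.py correctly, forward() L2-normalizes both embeddings and returns scaled cosine similarities between every image/text pair, which explains the [Batch, 2] shape: 10 images scored against 2 texts. A numpy sketch of that computation, where the function name clip_logits and the scale value 100.0 are illustrative assumptions (CLIP learns its own scale):

```python
import numpy as np


def clip_logits(image_features, text_features, logit_scale):
    """Sketch of CLIP's forward(): scaled cosine similarity for every image/text pair."""
    # L2-normalize each embedding row
    image_features = image_features / np.linalg.norm(image_features, axis=1, keepdims=True)
    text_features = text_features / np.linalg.norm(text_features, axis=1, keepdims=True)
    logits_per_image = logit_scale * image_features @ text_features.T  # [n_images, n_texts]
    logits_per_text = logits_per_image.T                               # [n_texts, n_images]
    return logits_per_image, logits_per_text


rng = np.random.default_rng(0)
img = rng.standard_normal((10, 1024))  # shaped like the output of encode_image
txt = rng.standard_normal((2, 1024))   # shaped like the output of encode_text
per_image, per_text = clip_logits(img, txt, logit_scale=100.0)
print(per_image.shape, per_text.shape)  # (10, 2) (2, 10)
```

So the embeddings exist inside the exported graph, but only as intermediate values; the graph's outputs are the similarity matrices.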

This doesn't look related to quantization. Could you use a different tag? We don't have an onnxruntime tag right now, so maybe just remove the quantization tag; there might be someone else who is able to help.