RuntimeError when trying to export pth to onnx

Hello everybody,

When I try to export pth file to onnx, I get an error:

Traceback (most recent call last):
  File "/home/bzeren/projects/visual_tools/LightGlue/torch2onnx.py", line 27, in <module>
    torch.onnx.export(
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/bzeren/projects/visual_tools/LightGlue/lightglue/superpoint.py", line 153, in forward
    assert key in data, f"Missing key {key} in data"
  File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/_tensor.py", line 1061, in __contains__
    raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.

Here is my code snippet:

import numpy as np

import torch
import torch.onnx

from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image, rbd
from lightglue import viz2d

image0 = load_image("/home/bzeren/projects/visual_tools/LightGlue/images/1.jpg")
image1 = load_image("/home/bzeren/projects/visual_tools/LightGlue/images/2.jpg")

# -*- SuperPoint -*-

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SuperPoint(max_num_keypoints=2048)
model.load_state_dict(torch.load('/home/bzeren/projects/visual_tools/LightGlue/models/superpoint_v1.pth'))
model.eval()

extractor_path = 'export.onnx'

torch.onnx.export(
            model,
            image0[None],
            extractor_path,
            input_names=["image"],
            output_names=["keypoints", "scores", "descriptors"],
            opset_version=17,
            dynamic_axes={
            "keypoints": {1: "num_keypoints"},
            "scores": {1: "num_keypoints"},
            "descriptors": {1: "num_keypoints"},
        },
        )

Is there any idea to fix this error?