Hello everybody,
When I try to export pth file to onnx, I get an error:
Traceback (most recent call last):
File "/home/bzeren/projects/visual_tools/LightGlue/torch2onnx.py", line 27, in <module>
torch.onnx.export(
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
outs = ONNXTracedModule(
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 133, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/jit/_trace.py", line 124, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/bzeren/projects/visual_tools/LightGlue/lightglue/superpoint.py", line 153, in forward
assert key in data, f"Missing key {key} in data"
File "/home/bzeren/projects/visual_tools/LightGlue/venv/lib/python3.10/site-packages/torch/_tensor.py", line 1061, in __contains__
raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.
Here is my code snippet:
import numpy as np
import torch
import torch.onnx
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image, rbd
from lightglue import viz2d
image0 = load_image("/home/bzeren/projects/visual_tools/LightGlue/images/1.jpg")
image1 = load_image("/home/bzeren/projects/visual_tools/LightGlue/images/2.jpg")
# -*- SuperPoint -*-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SuperPoint(max_num_keypoints=2048)
model.load_state_dict(torch.load('/home/bzeren/projects/visual_tools/LightGlue/models/superpoint_v1.pth'))
model.eval()
extractor_path = 'export.onnx'
torch.onnx.export(
model,
image0[None],
extractor_path,
input_names=["image"],
output_names=["keypoints", "scores", "descriptors"],
opset_version=17,
dynamic_axes={
"keypoints": {1: "num_keypoints"},
"scores": {1: "num_keypoints"},
"descriptors": {1: "num_keypoints"},
},
)
Is there any idea to fix this error?