Hi everyone,
I’m currently trying to export SAM (the Segment Anything Model) with PyTorch so that I can use it from C++. However, torch.jit.script seems unable to handle the other classes involved. I have read many related topics, but I still don’t understand why it cannot recognize the Sam type.
What I’m trying to do:
from segment_anything import SamPredictor, sam_model_registry
import torch
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)
sm = torch.jit.script(predictor)
sm.save("sam.pt")
The error:
Traceback (most recent call last):
  File "my_torch_project_folder_path\from_py_to_cpp.py", line 9, in <module>
    sm = torch.jit.script(predictor)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "my_torch_project_folder_path\.venv\Lib\site-packages\torch\jit\_script.py", line 1405, in script
    return torch.jit._recursive.create_script_class(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "my_torch_project_folder_path\.venv\Lib\site-packages\torch\jit\_recursive.py", line 525, in create_script_class
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
  File "my_torch_project_folder_path\.venv\Lib\site-packages\torch\jit\_recursive.py", line 61, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError:
Unknown type name 'Sam':
  File "my_torch_project_folder_path\.venv\Lib\site-packages\segment_anything\predictor.py", line 20
    def __init__(
        self,
        sam_model: Sam,
                   ~~~ <--- HERE
    ) -> None:
        """
The constructor of the SamPredictor class (in predictor.py):
import numpy as np
import torch
from segment_anything.modeling import Sam
from typing import Optional, Tuple
from .utils.transforms import ResizeLongestSide

class SamPredictor:
    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.reset_image()
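To check whether this is specific to SAM, here is a minimal sketch that does not need segment_anything at all. `TinyModel`, `Wrapper`, and `try_script` are hypothetical names of mine: `TinyModel` stands in for Sam (an nn.Module subclass) and `Wrapper` stands in for SamPredictor (a plain Python class whose constructor is annotated with that module type). As far as I understand, torch.jit.script compiles a plain object as a TorchScript class and has to resolve every type annotation in `__init__`, and an nn.Module subclass does not seem to be resolvable as a type there, which would explain the "Unknown type name" message. Scripting the nn.Module itself takes a different code path.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for Sam: any nn.Module subclass will do."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class Wrapper:
    """Hypothetical stand-in for SamPredictor: a plain class, not an nn.Module."""

    def __init__(self, model: TinyModel) -> None:
        self.model = model


def try_script(obj):
    """Return (scripted_obj, None) on success or (None, error_message) on failure."""
    try:
        return torch.jit.script(obj), None
    except RuntimeError as exc:
        return None, str(exc)


# Scripting the plain wrapper instance goes through the TorchScript class
# compiler, which must resolve the `TinyModel` annotation as a type.
_, wrapper_error = try_script(Wrapper(TinyModel()))
print("Wrapper:", wrapper_error)

# Scripting the nn.Module directly does not depend on that annotation.
scripted_model, model_error = try_script(TinyModel())
print("TinyModel:", model_error)
```

If the first call fails with the same kind of message while the second one succeeds, the problem is the plain wrapper class rather than the model weights or the checkpoint. That does not by itself mean Sam can be scripted; it only narrows down where this particular error comes from.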
Does anyone have an explanation or a solution? Do I have to export to the ONNX format and use the ONNX Runtime library for this kind of thing? torch.jit.script seemed like a pretty good fit for what I want to do. I feel like I’m missing something important, and I’m reluctant to give up.
Thank you very much for your help.