Export Segment Anything Model to C++

Hi everyone,
I’m currently trying to export SAM (Segment Anything Model) with PyTorch so that I can use it in C++. However, the export doesn’t seem to handle the other classes SAM is built from. I’ve read many topics on this, but I still can’t understand why the Sam type isn’t recognized.

What I’m trying to do:

from segment_anything import SamPredictor, sam_model_registry
import torch


sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda")
predictor = SamPredictor(sam)

sm = torch.jit.script(predictor)
sm.save("sam.pt")

The error:

Traceback (most recent call last):
  File "my_torch_project_folder_path\from_py_to_cpp.py", line 9, in <module>
    sm = torch.jit.script(predictor)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "my_torch_project_folder_path\.venv\Lib\site-packages\torch\jit\_script.py", line 1405, in script
    return torch.jit._recursive.create_script_class(obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "my_torch_project_folder_path\.venv\Lib\site-packages\torch\jit\_recursive.py", line 525, in create_script_class
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
  File "my_torch_project_folder_path\.venv\Lib\site-packages\torch\jit\_recursive.py", line 61, in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError:
Unknown type name 'Sam':
  File "my_torch_project_folder_path\.venv\Lib\site-packages\segment_anything\predictor.py", line 20
    def __init__(
        self,
        sam_model: Sam,
                   ~~~ <--- HERE
    ) -> None:
        """

The constructor of the SamPredictor class (segment_anything/predictor.py):

import numpy as np
import torch

from segment_anything.modeling import Sam

from typing import Optional, Tuple

from .utils.transforms import ResizeLongestSide


class SamPredictor:
    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.reset_image()
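
Because `SamPredictor` is a plain Python class rather than an `nn.Module`, `torch.jit.script` compiles it as a TorchScript class, so every annotation in its methods must resolve to a TorchScript type. An `nn.Module` subclass such as `Sam` does not appear to qualify, which seems to be where `Unknown type name 'Sam'` comes from. The minimal sketch below reproduces the same situation without SAM; `TinyModel` and `TinyPredictor` are hypothetical stand-ins, not part of `segment_anything`.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for Sam: an ordinary nn.Module subclass."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class TinyPredictor:
    """Hypothetical stand-in for SamPredictor: a plain class, not an nn.Module."""

    def __init__(self, model: TinyModel) -> None:
        self.model = model


def try_script_predictor() -> str:
    """Script a TinyPredictor instance and return the error text, if any."""
    predictor = TinyPredictor(TinyModel())
    try:
        # A plain object is compiled as a TorchScript class, so the
        # `model: TinyModel` annotation must resolve to a TorchScript type.
        torch.jit.script(predictor)
    except RuntimeError as err:
        return str(err)
    return "scripted without error"


if __name__ == "__main__":
    print(try_script_predictor())
```

Run as a script (TorchScript needs to read the source file), this should print an `Unknown type name 'TinyModel'` error that mirrors the traceback above.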

Does anyone have an explanation or a solution? Do I have to export to the ONNX format and use the ONNX Runtime library for this kind of thing? torch.jit.script seemed like a good fit for what I want to do. I feel like I’m missing something important, and I’m reluctant to give up on it.
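
ONNX is not the only route. One option, sketched here under the assumption that the part needed in C++ is a tensor-in/tensor-out `nn.Module` (for SAM, presumably a submodule such as `sam.image_encoder`, with the predictor’s pre- and post-processing rewritten on the C++ side), is to trace that module with `torch.jit.trace` and save it for libtorch’s `torch::jit::load`. `TinyEncoder` is a hypothetical stand-in, and this has not been checked against SAM’s own modules.

```python
import io

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a tensor-only part of SAM (e.g. the image encoder)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(image))


def export_and_reload() -> bool:
    """Trace, serialize, reload, and check the reloaded module matches eager mode."""
    model = TinyEncoder().eval()
    example = torch.randn(1, 3, 32, 32)

    # Tracing records the tensor ops run for this example input,
    # so no Python type annotations need to be resolved.
    with torch.no_grad():
        traced = torch.jit.trace(model, example)

    # In-memory round trip; traced.save("encoder.pt") would instead write a
    # file that libtorch can open with torch::jit::load.
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)

    with torch.no_grad():
        return torch.allclose(model(example), reloaded(example))


if __name__ == "__main__":
    print(export_and_reload())
```

Tracing freezes any data-dependent control flow to the path taken for the example input; scripting an `nn.Module` is the alternative when that matters.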

Thank you very much for your help.

Just after making this post, I went back over my earlier research and found this answer: Why can't I use classes for type annotation of the arguments to a function under the `torch.jit.script` decorator? - Stack Overflow
It seems that you can’t use your own classes in type annotations, only the types listed here: TorchScript Language Reference — PyTorch 2.3 documentation
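
The restriction bites here because `SamPredictor` is compiled as a TorchScript class. When the object passed to `torch.jit.script` is an `nn.Module`, its `__init__` runs as ordinary Python and only `forward` (plus whatever it calls) is compiled, with submodules scripted recursively. A minimal sketch of that pattern, using hypothetical `TinyModel` and `TinyPredictorModule` classes:

```python
from typing import Tuple

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for Sam."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class TinyPredictorModule(nn.Module):
    """Hypothetical wrapper: an nn.Module holding the model as a submodule."""

    def __init__(self, model: TinyModel) -> None:
        # __init__ is not compiled when scripting an nn.Module,
        # so this annotation never has to be a TorchScript type.
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Example tensor-only pre/post-processing that TorchScript can compile.
        scores = self.model(x * 2.0)
        return torch.softmax(scores, dim=-1)


def build_scripted_predictor() -> Tuple[nn.Module, torch.jit.ScriptModule]:
    """Return the eager wrapper and its scripted counterpart."""
    eager = TinyPredictorModule(TinyModel()).eval()
    return eager, torch.jit.script(eager)


if __name__ == "__main__":
    _, scripted = build_scripted_predictor()
    print(scripted.code)
```

As far as I can tell, applying this to SAM would also mean rewriting the predictor logic in tensor-only code, since `SamPredictor.set_image` and `predict` take NumPy arrays, which TorchScript cannot compile.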

Hi William, did you get SAM working in C++?