Loading TorchScript Module : class method not recognized during compilation

My objective is to use the following class and script, derived from this excellent work (https://github.com/timesler/facenet-pytorch), from C++ by exporting it with TorchScript:

from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import pandas as pd
import os

workers = 0 if os.name == 'nt' else 4

def collate_fn(x):
    """Unwrap a batch by returning only its first sample.

    Passed as ``DataLoader(collate_fn=...)`` so that each iteration yields the
    sample itself instead of a one-element list. Everything after the first
    element is discarded, so this assumes a batch size of 1.
    """
    sole_sample = x[0]
    return sole_sample

def describe(x):
    """Print a tensor's type string, its shape and its values to stdout."""
    tensor_type = x.type()
    print("Type: {}".format(tensor_type))
    shape = x.shape
    print("Shape/size: {}".format(shape))
    print("Values: \n{}".format(x))

class GetFaceEmbedding(torch.nn.Module):
    """Compute FaceNet embeddings for every face found in an image folder.

    NOTE(review): ``getFaceEmbedding`` is a classmethod and the class defines
    no ``forward``. ``torch.jit.script`` only compiles ``forward`` and methods
    marked with ``@torch.jit.export``, so this method is not part of a
    scripted/saved module and cannot be called on it from C++.
    """

    def __init__(self):
        super(GetFaceEmbedding, self).__init__()

    @classmethod
    def getFaceEmbedding(cls, imagePath):
        """Detect, align and embed the faces found under ``imagePath``.

        Args:
            imagePath: root directory laid out as expected by
                ``torchvision.datasets.ImageFolder`` (one sub-folder per class).

        Returns:
            A detached CPU tensor produced by InceptionResnetV1 from the stacked
            aligned faces (``torch.Size([1, 512])`` for a single detected face).

        Raises:
            RuntimeError: if no face was detected in any image.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        print('Running on device: {}'.format(device))

        # NOTE(review): both models are rebuilt (and weights reloaded) on every
        # call; hoist them if this is invoked repeatedly.
        mtcnn = MTCNN(
            image_size=160, margin=0, min_face_size=20,
            thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
            device=device
        )
        resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

        dataset = datasets.ImageFolder(imagePath)
        loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)

        aligned = []
        for x, _ in loader:
            x_aligned, prob = mtcnn(x, return_prob=True)
            # Skip images for which MTCNN returned None (presumably no face).
            if x_aligned is not None:
                print('Face detected with probability: {:8f}'.format(prob))
                aligned.append(x_aligned)

        if not aligned:
            # torch.stack([]) would fail with an opaque message; explain instead.
            raise RuntimeError(
                'No face detected in any image under {!r}'.format(imagePath)
            )

        aligned = torch.stack(aligned).to(device)
        embeddings = resnet(aligned).detach().cpu()
        return embeddings

With Python it works fine:

(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 .  
/getFaceEmbedding-01.py 
Running on device: cpu
Face detected with probability: 0.999430
Type: torch.FloatTensor
Shape/size: torch.Size([1, 512])
Values: 
tensor([[ 3.6307e-02, -8.8092e-02, -3.5002e-02, -8.2932e-02,  1.9032e-02,
          2.3228e-02,  2.4253e-02, -3.7844e-02, -6.8906e-02,  2.0351e-02,
         -6.7093e-02,  3.6181e-02, -2.5933e-02, -6.0015e-02,  2.6653e-02,
          9.4335e-02, -2.9241e-02, -2.8357e-02,  7.2207e-02, -3.7747e-02,
          6.3515e-03, -3.0220e-02, -2.4530e-02,  1.0004e-01,  6.6520e-02,
          ....
          3.2497e-02,  2.3421e-02, -5.3921e-02,  1.9589e-02, -2.8655e-03,
          1.3474e-02, -2.2743e-02,  3.2976e-02, -5.6658e-02,  2.0837e-02,
         -4.7152e-02, -6.5534e-02]])

Following the indications found here: https://pytorch.org/tutorials/advanced/cpp_export.html

I added to getFaceEmbedding.py these lines:

# Compile the module with TorchScript and serialize it for torch::jit::load.
# NOTE(review): torch.jit.script compiles forward() and @torch.jit.export
# methods only; a classmethod such as getFaceEmbedding is not compiled and
# therefore is not available on the saved module.
my_module = GetFaceEmbedding()
sm = torch.jit.script(my_module)
torch.jit.save(sm, "annotated_get_face_embedding.pt")

Running the script then saved the serialized file:

(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 
./getFaceEmbedding.py

-rw-r--r-- 1 marco marco 1,4K mar 19 18:52 annotated_get_face_embedding.pt

And created this cpp file:

(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ nano 

faceEmbedding.cpp :

#include <torch/script.h>
#include <iostream>
#include <memory>
#include <filesystem>

int main(int argc, const char* argv[]) {
    //if(argc != 3) {
    //    std::cerr << "usage:usage: faceEmbedding <path-to-exported-script-module> <path-to-image-f 
ile> \n";
    //    return -1;
    //}

  torch::jit::script::Module module;
  try {
      // Deserialize the ScriptModule from a file using torch::jit::load().
      module = torch::jit::load(argv[1]);
      std::filesystem::path imgPath = argv[2];

      // Execute the model and turn its output into a tensor
      at::Tensor output = module.getFaceEmbedding(imgPath).ToTensor();
  }
  catch (const c10::Error& e) {
      std::cerr << "error loading the model\n";
      return -1;
  }
  std::cout << "ok\n";
} // end of main() function

But during the compilation phase I get this error:
‘struct torch::jit::script::Module’ has no member named ‘getFaceEmbedding’

(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ mkdir build
(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ cd build
(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples/build$ cmake 
-DCMAKE_PREFIX_PATH=/home/marco/PyTorchMatters/libtorch ..
-- The C compiler identification is GNU 9.2.1
-- The CXX compiler identification is GNU 9.2.1
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
-- Found torch: /home/marco/PyTorchMatters/libtorch/lib/libtorch.so  
-- Configuring done
-- Generating done
-- Build files have been written to: /home/marco/PyTorchMatters/facenet_pytorch/examples/build
(venv373) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples/build$ cmake --build . 
--config Release
Scanning dependencies of target faceEmbedding
[ 50%] Building CXX object CMakeFiles/faceEmbedding.dir/faceEmbedding.cpp.o
/home/marco/PyTorchMatters/facenet_pytorch/examples/faceEmbedding.cpp: In function ‘int 
main(int, const char**)’:
/home/marco/PyTorchMatters/facenet_pytorch/examples/faceEmbedding.cpp:20:34: error: ‘struct 
torch::jit::script::Module’ has no member named ‘getFaceEmbedding’
   20 |       at::Tensor output = module.getFaceEmbedding(imgPath).ToTensor();
      |                                  ^~~~~~~~~~~~~~~~
CMakeFiles/faceEmbedding.dir/build.make:62: recipe for target 'CMakeFiles/faceEmbedding.dir
/faceEmbedding.cpp.o' failed
make[2]: *** [CMakeFiles/faceEmbedding.dir/faceEmbedding.cpp.o] Error 1
CMakeFiles/Makefile2:75: recipe for target 'CMakeFiles/faceEmbedding.dir/all' failed
make[1]: *** [CMakeFiles/faceEmbedding.dir/all] Error 2
Makefile:83: recipe for target 'all' failed
make: *** [all] Error 2

How to solve the problem?
Looking forward to your kind help.
Marco

I modified the method as follows:

def forward(self, imagePath):

Now I get this error:
"torch.jit.frontend.UnsupportedNodeError: DictComp aren’t supported"

  (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 
  ./getFaceEmbedding.py 
  Traceback (most recent call last):
    File "./getFaceEmbedding.py", line 127, in <module>
      sm = torch.jit.script(my_module)
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1255, in 
  script
      return torch.jit._recursive.recursive_script(obj)
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 534, in 
  recursive_script
      return create_script_module(nn_module, infer_methods_to_compile(nn_module))
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 493, in 
  infer_methods_to_compile
      stubs.append(make_stub_from_method(nn_module, method))
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 40, in 
  make_stub_from_method
      return make_stub(func)
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 33, in 
  make_stub
      ast = torch.jit.get_jit_def(func, self_name="RecursiveScriptModule")
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 171, in 
  get_jit_def
      return build_def(ctx, py_ast.body[0], type_line, self_name)
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 212, in 
  build_def
      build_stmts(ctx, body))
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 127, in 
  build_stmts
      stmts = [build_stmt(ctx, s) for s in stmts]
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 127, in 
  <listcomp>
      stmts = [build_stmt(ctx, s) for s in stmts]
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 187, in 
  __call__
      return method(ctx, node)
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 289, in 
  build_Assign
      rhs = build_expr(ctx, stmt.value)
    File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 186, in 
  __call__
      raise UnsupportedNodeError(ctx, node)
  torch.jit.frontend.UnsupportedNodeError: DictComp aren't supported:
    File "./getFaceEmbedding.py", line 79
          dataset = datasets.ImageFolder(imagePath)
          dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}
                                 ~ <--- HERE
          loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)

I read here: https://pytorch.org/tutorials/advanced/cpp_export.html
that “If you need to exclude some methods in your nn.Module because they use Python features that TorchScript doesn’t support yet, you could annotate those with @torch.jit.ignore”.
But I suspect that excluding the dictionary comprehension from serialization would break the method’s functionality.
What would you suggest me to do?

I refactored everything as follows:

from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import pandas as pd
import os

workers = 0 if os.name == 'nt' else 4

def collate_fn(x):
    """Unwrap a batch by returning only its first sample.

    Passed as ``DataLoader(collate_fn=...)`` so that each iteration yields the
    sample itself instead of a one-element list. Everything after the first
    element is discarded, so this assumes a batch size of 1.
    """
    sole_sample = x[0]
    return sole_sample

def describe(x):
    """Print a tensor's type string, its shape and its values to stdout."""
    tensor_type = x.type()
    print("Type: {}".format(tensor_type))
    shape = x.shape
    print("Shape/size: {}".format(shape))
    print("Values: \n{}".format(x))

def dataLoader(imagePath):
    """Build an ImageFolder dataset for ``imagePath`` and a DataLoader over it.

    The dataset is given an extra ``idx_to_class`` attribute, the inverse of
    its ``class_to_idx`` mapping.

    NOTE(review): TorchScript also compiles plain functions called from
    scripted code, so calling this from ``forward`` presumably does not by
    itself make ``forward`` scriptable — confirm.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = datasets.ImageFolder(imagePath)
    inverse = {}
    for class_name, class_index in dataset.class_to_idx.items():
        inverse[class_index] = class_name
    dataset.idx_to_class = inverse
    return dataset, DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)

class GetFaceEmbedding(torch.nn.Module):
    """MTCNN face detection followed by InceptionResnetV1 (VGGFace2) embedding.

    NOTE(review): as written this module cannot be compiled by
    ``torch.jit.script``. Scripting recurses into every submodule, and the
    ``mtcnn`` submodule's code uses ``with torch.no_grad():``, which the
    installed TorchScript frontend rejects. ``forward`` itself also drives a
    Python ``DataLoader`` over a torchvision ``ImageFolder``. A workable split
    is to keep image loading and MTCNN alignment outside TorchScript and
    script/trace only ``self.resnet``, feeding it the stacked aligned faces
    (presumably supported by InceptionResnetV1 — verify).
    """

    def __init__(self):
        super(GetFaceEmbedding, self).__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        print('Running on device: {}'.format(self.device))

        self.mtcnn = MTCNN(
            image_size=160, margin=0, min_face_size=20,
            thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
            device=self.device
        )

        self.resnet = InceptionResnetV1(pretrained='vggface2').eval().to(self.device)

    def forward(self, imagePath):
        """Return embeddings for every face detected under ``imagePath``.

        Args:
            imagePath: root directory laid out as expected by
                ``torchvision.datasets.ImageFolder``.

        Returns:
            A detached CPU tensor produced by ``self.resnet`` from the stacked
            aligned faces (``torch.Size([1, 512])`` for a single detected face).

        Raises:
            RuntimeError: if no face was detected in any image.
        """
        # Only the loader is needed: the class names were collected but never used.
        _, loader = dataLoader(imagePath)

        aligned = []
        for x, _ in loader:
            x_aligned, prob = self.mtcnn(x, return_prob=True)
            # Skip images for which MTCNN returned None (presumably no face).
            if x_aligned is not None:
                print('Face detected with probability: {:8f}'.format(prob))
                aligned.append(x_aligned)

        if not aligned:
            # torch.stack([]) would fail with an opaque message; explain instead.
            raise RuntimeError(
                'No face detected in any image under {!r}'.format(imagePath)
            )

        aligned = torch.stack(aligned).to(self.device)
        embeddings = self.resnet(aligned).detach().cpu()
        return embeddings


# Compile the module with TorchScript and serialize it for torch::jit::load.
# torch.jit.script also recursively scripts the submodules (self.mtcnn and
# self.resnet), so each of them must be TorchScript-compatible too.
my_module = GetFaceEmbedding()
sm = torch.jit.script(my_module)
torch.jit.save(sm, "annotated_get_face_embedding.pt")

With Python it works fine:

(opencv4) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 .   
/getFaceEmbedding.py
Running on device: cpu
Face detected with probability: 0.999430
Type: torch.FloatTensor
Shape/size: torch.Size([1, 512])
Values: 
tensor([[ 3.6307e-02, -8.8092e-02, -3.5002e-02, -8.2932e-02,  1.9032e-02,
          2.3228e-02,  2.4253e-02, -3.7844e-02, -6.8906e-02,  2.0351e-02,
          ....
          1.3474e-02, -2.2743e-02,  3.2976e-02, -5.6658e-02,  2.0837e-02,
         -4.7152e-02, -6.5534e-02]])

But when I try to serialize it, I get this new error:
“torch.jit.frontend.UnsupportedNodeError: with statements aren’t supported”
and I do not understand where it comes from:

(opencv4) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3   
./getFaceEmbedding.py
Running on device: cpu
Traceback (most recent call last):
  File "./getFaceEmbedding.py", line 127, in <module>
    sm = torch.jit.script(my_module)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1255, in 
script
    return torch.jit._recursive.recursive_script(obj)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 534, in 
recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 296, in 
create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 336, in 
create_script_module_impl
     script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1593, in 
_construct
    init_fn(script_module)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 328, in 
init_fn
    scripted = recursive_script(orig_value)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 534, in 
recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 493, in 
infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 40, in 
make_stub_from_method
    return make_stub(func)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 33, in 
make_stub
    ast = torch.jit.get_jit_def(func, self_name="RecursiveScriptModule")
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 171, in 
get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 212, in 
build_def
    build_stmts(ctx, body))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 127, in 
build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 127, in 
<listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 186, in 
__call__
    raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: with statements aren't supported:
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/models/mtcnn.py", 
line 246

        # Detect faces
        with torch.no_grad():
        ~~~~ <--- HERE
            batch_boxes, batch_probs = self.detect(img)

The error is apparently raised in this line of code from MTCNN.
I’m not familiar with the model implementation, but you could try to remove the with statement or replace it with a decorator.

As far as I understand this line of code

    with torch.no_grad():

means “apply to the execution of the next line of code”:

        batch_boxes, batch_probs = self.detect(img)

“the runtime context of torch.no_grad(), which sets all requires_grad flags to false, meaning we don’t want PyTorch to calculate the gradients of the newly defined variables batch_boxes and batch_probs”.
Is my interpretation correct or am I saying something wrong?

In the code example here: https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients the with statement is actually used:

# Verbatim from the PyTorch autograd tutorial; x is defined there
# (presumably a tensor created with requires_grad=True — not shown here).
print(x.requires_grad)
print((x ** 2).requires_grad)

# Inside the torch.no_grad() block the same expression is expected to report
# requires_grad == False.
with torch.no_grad():
    print((x ** 2).requires_grad)

Alternatively, it uses detach() to get a new Tensor with the same content that does not require gradients:

# Verbatim from the PyTorch autograd tutorial; x is defined there
# (presumably a tensor created with requires_grad=True — not shown here).
print(x.requires_grad)
# detach() gives a tensor with the same content that does not require grad.
y = x.detach()
print(y.requires_grad)
# Element-wise comparison: the values are the same, so this should be all-True.
print(x.eq(y).all())

I tried to modify in
/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/models/mtcnn.py

the lines of code:

    with torch.no_grad():
    batch_boxes, batch_probs = self.detect(img)

to:

    #with torch.no_grad():
    batch_boxes, batch_probs = self.detect(img)
    batch_boxes = batch_boxes.detach()
    batch_probs = batch_probs.detach()

But I get this error message:

(opencv4) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 
./getFaceEmbedding.py
Running on device: cpu
Traceback (most recent call last):
  File "./getFaceEmbedding.py", line 126, in <module>
    sm = torch.jit.script(my_module)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1255, in script
    return torch.jit._recursive.recursive_script(obj)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 534, in 
recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 296, in 
create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 336, in 
create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1593, in 
_construct
    init_fn(script_module)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 328, in init_fn
    scripted = recursive_script(orig_value)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 534, in 
recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 296, in 
create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 340, in 
create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 259, in 
create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 570, in 
compile_unbound_method
    stub = make_stub(fn)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 33, in 
make_stub
    ast = torch.jit.get_jit_def(func, self_name="RecursiveScriptModule")
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 171, in 
get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 212, in 
build_def
    build_stmts(ctx, body))
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 127, in 
build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 127, in 
<listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/marco/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 186, in 
__call__
    raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: with statements aren't supported:
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/models/mtcnn.py", line 
346
        """

        with torch.no_grad():
        ~~~~ <--- HERE
            batch_boxes, batch_points = detect_face(
                img, self.min_face_size,

Python decorators (which apply a function to another function) could help… but how?
Neither of my attempts worked:

First attempt:

    @torch.no_grad
    batch_boxes, batch_probs = self.detect(img)


(opencv4) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 
./getFaceEmbedding.py
Traceback (most recent call last):
  File "./getFaceEmbedding.py", line 16, in <module>
    from facenet_pytorch import MTCNN, InceptionResnetV1
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/__init__.py", line 2, in
 <module>
    from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten, fixed_image_standardization
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/models/mtcnn.py", line 
248
    batch_boxes, batch_probs = self.detect(img)
              ^
SyntaxError: invalid syntax

Second attempt:

batch_boxes, batch_probs = self.detect(img)

@torch.no_grad
def detect(self, img, landmarks=False):

(opencv4) (base) marco@pc01:~/PyTorchMatters/facenet_pytorch/examples$ python3 . 
/getFaceEmbedding.py
Traceback (most recent call last):
  File "./getFaceEmbedding.py", line 16, in <module>
    from facenet_pytorch import MTCNN, InceptionResnetV1
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/__init__.py", line 2, in 
<module>
    from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten, fixed_image_standardization
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/models/mtcnn.py", line 
157, in <module>
    class MTCNN(nn.Module):
  File "/home/marco/anaconda3/lib/python3.7/site-packages/facenet_pytorch/models/mtcnn.py", line 
308, in MTCNN
    def detect(self, img, landmarks=False):
TypeError: no_grad() takes no arguments

I’m not a Python expert… so… how would you use the decorator to apply the no_grad() runtime context to the detect() method?