How to trace external modules for model serialization

Hi. I originally tagged this question as quantization, but the problem really seems to be how to save a model with torch.jit.save when it contains calls to external modules. Is this possible? I really appreciate any help you can provide.

There are a few parts to this:

  • The official answer is to provide a custom operator in C++ (as torchvision does for e.g. nms) and then call it through torch.ops.mymodule.opname. This is compatible with the JIT, including saving and loading (see the first sketch after this list).

  • The JIT has a Python fallback: if you decorate a function with @torch.jit.ignore and call it from your JITed function, that call drops back into the Python interpreter. This lets you trace the model, but you won’t be able to save it (see the second sketch after this list).

  • You could register a “stub” op that reflects back to Python, or write a little surgery helper that replaces the Python fallbacks with the stub op before saving and swaps them back after loading.
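
For reference, here is a minimal usage sketch of the custom-op route, using torchvision’s nms as the example op (the box and score values are made up):

import torch
import torchvision  # importing torchvision loads the library that registers torch.ops.torchvision.nms

@torch.jit.script
def keep_best(boxes: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    # custom ops are called through torch.ops.<namespace>.<opname>, in eager mode and in TorchScript alike
    return torch.ops.torchvision.nms(boxes, scores, 0.5)

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])
print(keep_best(boxes, scores))
# keep_best.save("keep_best.pt") also works, provided whoever loads it
# first imports/loads the library that registers the op.

And here is a sketch of the @torch.jit.ignore fallback (Blur and numpy_step are made-up names, just to illustrate):

import numpy as np
import torch
import torch.nn as nn

class Blur(nn.Module):
    @torch.jit.ignore
    def numpy_step(self, x: torch.Tensor) -> torch.Tensor:
        # arbitrary Python / third-party code; the compiler never looks inside
        return torch.from_numpy(np.asarray(x) * 0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.numpy_step(x) + 1.0

m = torch.jit.script(Blur())
print(m(torch.ones(3)))   # works: the ignored method falls back to the Python interpreter
# m.save("blur.pt")       # fails: Python function calls cannot be serialized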

Best regards

Thomas

Hi Thomas and thanks for your answer. I wonder if you can help me understand this use case:

  • I have a plain PyTorch nn.Module Python class whose forward makes a single call to an external library, OpenCV (for image processing)
  • I run torch.jit.trace() on the model and save the result
  • I look at the files unzipped from my saved model.pt and see that inside the generated torch.py there is no call to OpenCV. This makes sense to me, because how would the C++ libtorch runtime know to call every shared object, such as libopencv.so, during inference?
  • Is my understanding correct? If so, how do I ask torch.jit.trace() to raise an error or warning about my code calling the OpenCV library? (See the sketch right after this list for the kind of thing I have in mind.)
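
For example, this is what I am hoping for (just a sketch, assuming the tensor-to-Python conversion is reported as a torch.jit.TracerWarning, and using the Net module shown below):

import warnings
import torch

# escalate tracer warnings (e.g. about converting a tensor to a Python list) into hard errors
warnings.filterwarnings("error", category=torch.jit.TracerWarning)

model = Net()
traced = torch.jit.trace(model, torch.randn(28, 28))  # would now fail loudly at the x.tolist() call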

Here is the code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
class Net(nn.Module):
    def __init__(self, image_size:int=28):
        super(Net, self).__init__()
        self.image_size = image_size
        # fc1 input size: 64 channels times the spatial size after two 3x3 convs and a 2x2 max-pool
        n_linear1 = 64 * int((image_size - 4) / 2) * int((image_size - 4) / 2)
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(n_linear1, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        # Run some preprocessing that torch.jit.trace cannot follow, because the x.tolist() call leaves tensor land
        x = x.tolist()
        x = np.array(x, dtype=np.float32).reshape((self.image_size,self.image_size))
        kernel = np.ones((5,5),np.float32)/25.0
        x = cv2.filter2D(x,-1,kernel)
        x = np.array(x).reshape((-1,1,self.image_size,self.image_size))
        x = torch.Tensor(x)
        # from here, all is pure Torch code
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
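
This is roughly how I trace and save it (the 28x28 example input, the eval() call, and the file name are just placeholders):

model = Net().eval()
example = torch.randn(28, 28)             # a single 28x28 "image", matching image_size
traced = torch.jit.trace(model, example)  # emits TracerWarnings but completes
traced.save("model.pt")
print(traced.code)                        # the OpenCV / numpy preprocessing is nowhere to be seen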

After JIT tracing, here is the main generated Python file:

class Net(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : Optional[bool]
  conv1 : __torch__.torch.nn.modules.conv.Conv2d
  conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d
  dropout1 : __torch__.torch.nn.modules.dropout.Dropout
  dropout2 : __torch__.torch.nn.modules.dropout.___torch_mangle_1.Dropout
  fc1 : __torch__.torch.nn.modules.linear.Linear
  fc2 : __torch__.torch.nn.modules.linear.___torch_mangle_2.Linear
  def forward(self: __torch__.Net,
    x: Tensor) -> Tensor:
    fc2 = self.fc2
    dropout2 = self.dropout2
    fc1 = self.fc1
    dropout1 = self.dropout1
    conv2 = self.conv2
    conv1 = self.conv1
    input = torch.to(torch.lift_fresh(CONSTANTS.c0), torch.device("cpu"), 6)
    input0 = torch.relu((conv1).forward(input, ))
    input1 = torch.relu((conv2).forward(input0, ))
    input2 = torch.max_pool2d(input1, [2, 2], annotate(List[int], []), [0, 0], [1, 1])
    input3 = torch.flatten((dropout1).forward(input2, ), 1)
    input4 = torch.relu((fc1).forward(input3, ))
    _0 = (fc2).forward((dropout2).forward(input4, ), )
    return torch.log_softmax(_0, 1)