Exporting maskrcnn_resnet50_fpn to (anything) for use in C++

Hey :slight_smile:

I really wanted to do inference in C++/CUDA, but was having a very difficult time getting my model to export to any format that would enable this. The ONNX export just plain didn’t work correctly (always the same score, many warnings during the export process), and the NVIDIA tools couldn’t handle it either. I finally got TorchScript working after looking around many places here on the forums and in various GitHub issues, and pieced together something that worked for me, so I just wanted to share it back in case it saves somebody else some time.

Disclaimer:

  1. I’m sure the ONNX / other issues were me being a doofus, new to this framework and all.
  2. Just because this worked for me doesn’t mean it’s the right way of doing it. Code that seemed to work for other users wasn’t working for me for unknown reasons. So if nothing else, this is another option to try if you are stuck like I was.

It’s worth mentioning that for training I had been following this tutorial: TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 1.10.1+cu102 documentation

My code was based on a slightly older version of that tutorial. The get_instance_segmentation_model helper imported in the export script below isn’t shown there, but it boils down to this:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_instance_segmentation_model(num_classes, pretrained=True):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=pretrained)

    # get the number of input features for the box classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained box head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

So other than that, plus a lot of really useful tutorials elsewhere and a lot of issues resolved by some hard-working pros who made Mask R-CNN scriptable in the first place (thanks!!!), this is what I did. Hope it helps somebody like me who got realllly turned around trying to figure it out.

#!/usr/bin/env python3

import sys
from pathlib import Path

this_file_dir = Path(__file__).parent.absolute()
sys.path.insert(0, str(this_file_dir))

# See above, this just returns the model and the above path manipulation is just
# for me to be able to import it locally.  Create your model accordingly.
from some_local_file import get_instance_segmentation_model

import torch
from torch import nn
from torch import Tensor
from typing import Tuple


# NOTE: see discussion here:
# https://github.com/facebookresearch/detr/issues/238#issuecomment-694772945
#
# This was created by trial and error.  Ultimately I want to wrap the model and
# return my desired outputs from forward.  The C++ code is using toTuple(), not
# toTensor().  The main trick was finding out that you need to store the
# internal model and call self.model([inputs]) with the input wrapped in a
# list for MaskRCNN.  The output type depends on your application, but I got a
# tuple working so I kept it.
class WrappedMaskRCNN(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    # NOTE: what you return here kind of depends on what you want to do on the
    # C++ side.  The nice thing about the wrapping setup is you get to change
    # this to your liking, noting that the type annotations on the python
    # function signature must match what you actually return!
    #
    # For me I was doing this in C++ (you don't have to `to(cpu_device)`);
    # a fuller version of the C++ side is sketched at the end of the post:
    #
    #   auto output = module.forward(inputs).toTuple();
    #   const auto& elements = output->elements();
    #   const torch::Tensor& boxes = elements[0].toTensor().to(cpu_device);
    #   const torch::Tensor& labels = elements[1].toTensor().to(cpu_device);
    #   const torch::Tensor& scores = elements[2].toTensor().to(cpu_device);
    #   const torch::Tensor& masks = elements[3].toTensor().to(cpu_device);
    #
    # Another useful resource:
    # https://tebesu.github.io/posts/PyTorch-C++-Frontend
    def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
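        # In scripting mode torchvision's GeneralizedRCNN always returns the
        # (losses, detections) pair (hence the warning mentioned below); at
        # inference time `losses` is empty and `detections` has one dict per
        # input image.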
        losses, detections = self.model([inputs], None)
        return (
            detections[0]["boxes"],
            detections[0]["labels"],
            detections[0]["scores"],
            detections[0]["masks"]
        )


if __name__ == "__main__":
    print("==> BEGIN")
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # our dataset has two classes only - background and person
    num_classes = 2

    # get the model using our helper function
    print("==> Get the model.")
    model = get_instance_segmentation_model(
        num_classes=num_classes,
        pretrained=True)
    model.eval()  # use eval() (not `model.training = False`) so submodules switch too

    # (Optional, encouraged): load the checkpoint you've already trained to export.
    the_checkpoint = Path("/path/to/my/checkpoint.pt")
    print(f"==> Loading the model from '{the_checkpoint}'.")
    model.load_state_dict(
        torch.load(
            str(the_checkpoint),
            map_location="cpu",  # load to CPU first; we move to `device` below
        )["model_state_dict"]  # depends on how you saved things with torch.save
    )

    # NOTE: I don't know if this is necessary.  Just making sure it is in the right place.
    model = model.eval()
    model = model.to(device)

    wrapped_model = WrappedMaskRCNN(model)
    wrapped_model = wrapped_model.eval()
    wrapped_model = wrapped_model.to(device)

    # Set this to be the output filename you want to load in C++.
    output_f = this_file_dir / "output_jit_module.pt"
    print(f"==> Creating torch jit module {output_f}")

    # NOTE: want to make it stop warning about returning (Losses, Detections)?
    # It's good to know about and all, but I don't want that going on every time
    # I load the model in C++.  So just comment it out:
    # https://github.com/pytorch/vision/blob/afda28accbc79035384952c0359f0e4de8454cb3/torchvision/models/detection/generalized_rcnn.py#L107
    sm = torch.jit.script(wrapped_model)
    sm.save(str(output_f))
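
For completeness, here is roughly what the loading side looked like for me in C++. Treat this as a minimal sketch rather than a definitive recipe: it assumes you’ve linked against libtorch, that the file name matches what the script above saved, and the random input tensor is just a stand-in for whatever preprocessing produces your real CHW float image in [0, 1].

#include <torch/script.h>
#include <torch/cuda.h>

#include <iostream>
#include <vector>

int main() {
    const torch::Device device(
        torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
    const torch::Device cpu_device(torch::kCPU);

    // Load the module produced by torch.jit.script(...).save(...) above.
    torch::jit::script::Module module = torch::jit::load("output_jit_module.pt");
    module.to(device);
    module.eval();

    // The wrapper's forward takes a single CHW image tensor; a random one
    // stands in for a real preprocessed image here.
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::rand({3, 480, 640}).to(device));

    // forward() returns the (boxes, labels, scores, masks) tuple that
    // WrappedMaskRCNN defines.
    const auto output = module.forward(inputs).toTuple();
    const auto& elements = output->elements();
    const torch::Tensor boxes = elements[0].toTensor().to(cpu_device);
    const torch::Tensor labels = elements[1].toTensor().to(cpu_device);
    const torch::Tensor scores = elements[2].toTensor().to(cpu_device);
    const torch::Tensor masks = elements[3].toTensor().to(cpu_device);

    std::cout << "Got " << boxes.size(0) << " detections." << std::endl;
    return 0;
}

On the build side nothing special should be needed beyond the usual libtorch setup, e.g. the standard find_package(Torch REQUIRED) CMake arrangement from the PyTorch C++ docs.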

Like I said, I’m far from an expert and have no idea why the ONNX stuff wouldn’t work for me even though it seems to work for others. If the above doesn’t work for you I’m sorry, but I also won’t be able to help you :confused: Just sharing what finally worked for me in hopes that it helps somebody else, as a thanks to the pytorch community and the devs for making this possible :heart: It’s pretty damn cool how this stuff works :upside_down_face: