Shape Comparison exception in custom model trace

Custom model trace code:

  class WrappedDeeplabcustommodel(nn.Module):
      """Wrap InstanceSegmentation so that forward returns a single tensor.

      The wrapped model is built with a fixed configuration, restored from a
      local checkpoint and switched to eval mode; ``forward`` returns only
      the ``'masks'`` entry of the first element of the model's output.
      """

      def __init__(self):
          super().__init__()
          m = InstanceSegmentation(grid_size=32,
                                   factor=3,
                                   delta=0.72,
                                   pretrained=True,
                                   num_classes=23)
          # map_location='cpu' lets a checkpoint that was saved from GPU
          # tensors be restored on a machine without CUDA; load_state_dict
          # copies the values into the module's own parameters either way.
          m.load_state_dict(torch.load('./ckpt_root/49.pth', map_location='cpu'))
          self.m = m.eval()

      def forward(self, x):
          res = self.m(x)
          # NOTE(review): the size of this tensor presumably depends on how
          # many instances get detected, which a trace cannot represent.
          x = res[0]['masks']
          return x

  traceable_m = WrappedDeeplabcustommodel()
  with torch.no_grad():
      trace = torch.jit.trace(traceable_m, input_batch)

torch.jit.trace() runs three forward passes and then raises this error:

Comparison exception: expected tensor shape torch.Size([3, 4]) doesn’t match with actual tensor shape torch.Size([1])

This error sounds quite strange. Could you share a minimal, executable code snippet showing this error?
I assume that no model checkpoint loading is needed to see the shape mismatch?

The PyTorch version is 1.8.0+cu101 and the GPU is an RTX 2080 Ti.

The InstanceSegmentation() model is the MaskRCNN model shown in the tutorial here: TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 2.1.1+cu121 documentation.

However, instead of binary classification, we are classifying multiple classes. The model runs and gives correct outputs without torch.jit.trace.

However, there are two key differences in the model above compared to the tutorial:

  1. The InstanceSegmentation() inherits from nn.Module and is a wrapper for the get_model_instance_segmentation() in the tutorial
  2. ‘x’ in the WrappedDeeplabcustommodel() is a tensor instead of a list. That is, we send the tensor to InstanceSegmentation(). However, MaskRCNN in the tutorial requires a list as input. Thus, in the InstanceSegmentation() we convert the tensor to a list as follows:

# Iterating a tensor walks its first dimension, so this yields one tensor per batch entry.
images = list(x)

Finally, as per your suggestion, we removed m.load_state_dict(torch.load('./ckpt_root/49.pth')) but still get the same error.

@ptrblck Please check the reply below in this thread from my teammate.

Could you post a minimal, executable code snippet reproducing the issue, please?

Below are the code snippets (3 files - convert, model, task). Please review.
Use an input.jpg (any input image)

Run convert.py to reproduce the issue.

model.py code:

import torch
import torch.nn as nn
from task import MaskRCNN
import torch.nn.functional as F

class InstanceSegmentation(nn.Module):
    """Resize a batch of images and run it through a Mask R-CNN task network.

    The batch is bilinearly resized to a square of side
    ``grid_size * factor`` and split into a list of per-image tensors,
    because the task network built by ``MaskRCNN`` is called with a list of
    images rather than one batched tensor.

    Args:
        grid_size: Base grid size; the resize target side is
            ``grid_size * factor``.
        factor: Multiplier applied to ``grid_size`` for the resize target.
        delta: Stored on the module; not used by ``forward``.
        pretrained: Forwarded to ``MaskRCNN``.
        num_classes: Forwarded to ``MaskRCNN``.
    """

    def __init__(self, grid_size=32, factor=3, delta=0.72, pretrained=True, num_classes=0):
        super().__init__()
        self.grid_size = grid_size
        self.factor = factor
        self.delta = delta
        self.task_network = MaskRCNN(pretrained=pretrained, num_classes=num_classes)

    def forward(self, images, target=None):
        """Resize ``images`` and run the task network on them.

        Args:
            images: Batched image tensor. It must be 4-D, because bilinear
                ``F.interpolate`` only accepts 4-D input.
            target: Optional targets. They are forwarded to the task network
                only when not ``None``.

        Returns:
            The task network's output, unchanged. NOTE(review): for a
            torchvision Mask R-CNN this is presumably a loss dict in training
            mode and a list of per-image detection dicts in eval mode.
        """
        side = self.grid_size * self.factor
        images = F.interpolate(images,
                               (side, side),
                               mode='bilinear',
                               align_corners=True)

        # Split the batch along dim 0 into a list of per-image tensors.
        image_list = list(images)

        if target is None:
            return self.task_network(image_list)
        return self.task_network(image_list, target)

convert.py code:


import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import torch.nn as nn
import torchvision
import json
from model import InstanceSegmentation
from torchvision import transforms
from PIL import Image
from collections import namedtuple

import coremltools as ct
# Load the sample image. Converting to RGB guarantees three channels, which
# the 3-element mean/std below and the model input require (grayscale, CMYK
# or RGBA files would otherwise produce 1 or 4 channels). The context manager
# closes the file once convert() has decoded the pixels.
with Image.open("input.jpg") as raw_image:
    input_image = raw_image.convert("RGB")
input_image.show()

# ToTensor yields a float C x H x W tensor; Normalize then standardises each
# channel with the given mean/std.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

input_tensor = preprocess(input_image)
# Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
input_batch = input_tensor.unsqueeze(0)

class WrappedDeeplabcustommodel(nn.Module):
    """Adapter that returns a single tensor from InstanceSegmentation.

    ``forward`` keeps only the ``'masks'`` entry of the first element that
    the wrapped model returns and discards everything else.

    NOTE(review): the number of masks presumably varies with how many
    instances are detected; a trace freezes such data-dependent behaviour.
    """

    def __init__(self):
        super().__init__()
        segmenter = InstanceSegmentation(
            grid_size=32,
            factor=3,
            delta=0.72,
            pretrained=True,
            num_classes=23,
        )
        self.m = segmenter.eval()

    def forward(self, x):
        outputs = self.m(x)
        first_result = outputs[0]
        return first_result['masks']
traceable_m = WrappedDeeplabcustommodel()
traceable_m.eval()  # Module.eval() switches the module to inference mode in place
with torch.no_grad():
    # NOTE(review): tracing records only the ops executed for input_batch,
    # so any output size that depends on the data is baked into the graph.
    trace = torch.jit.trace(traceable_m, input_batch)

task.py code:

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def MaskRCNN(pretrained=True, num_classes=0):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with new prediction heads.

    The model's box predictor and mask predictor are replaced with freshly
    initialised heads sized for ``num_classes``. Only the remaining layers
    keep the loaded weights.

    Args:
        pretrained: Forwarded to
            ``torchvision.models.detection.maskrcnn_resnet50_fpn``. It controls
            whether COCO-pretrained weights are loaded before the heads are
            replaced.
        num_classes: Number of outputs of the new box and mask predictors.
            NOTE(review): the default of 0 presumably yields heads with no
            class outputs; callers should pass a positive value.

    Returns:
        The modified detection model returned by ``maskrcnn_resnet50_fpn``.
    """
    # Load an instance segmentation model, pre-trained on COCO when requested.
    # The caller's flag is honoured rather than always loading weights.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=pretrained)

    # Replace the box classifier/regressor with one sized for num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor likewise. The new head takes the same input
    # channels as the original head's conv5_mask layer.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256 * 4  # width of the new mask head's hidden layer
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

@ptrblck sample code to reproduce the issue has been uploaded above. Please review.

@ptrblck Can you please check the issue?

After replacing the undefined input image with a random tensor I see:

WARNING:root:Torch version 1.12.0a0+gitba81899 has not been tested with coremltools. You may run into unexpected errors. Torch 1.10.2 is the most recent version that has been tested.
...
Traceback (most recent call last):
  File "convert.py", line 54, in <module>
    trace = torch.jit.trace(traceable_m, input_batch)
  File "/opt/pytorch/pytorch/torch/jit/_trace.py", line 750, in trace
    return trace_module(
  File "/opt/pytorch/pytorch/torch/jit/_trace.py", line 992, in trace_module
    _check_trace(
  File "/opt/pytorch/pytorch/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/pytorch/pytorch/torch/jit/_trace.py", line 535, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
        Graph diff:
                  graph(%self.1 : __torch__.WrappedDeeplabcustommodel,
                        %x.1 : Tensor):
                    %m : __torch__.model.InstanceSegmentation = prim::GetAttr[name="m"](%self.1)
                    %4 : int = prim::Constant[value=1](), scope: __module.m/__module.m.task_network # /o
pt/pytorch/vision/torchvision/models/detection/generalized_rcnn.py:76:0
                    %5 : int = prim::Constant[value=2](), scope: __module.m/__module.m.task_network # /o
pt/pytorch/vision/torchvision/models/detection/generalized_rcnn.py:76:0
                    %6 : str = prim::Constant[value="constant"](), scope: __module.m/__module.m.task_net
work/__module.m.task_network.transform # /opt/pytorch/vision/torchvision/models/detection/transform.py:2
15:0
...

which seems to point to an unsupported coremltools usage and a failure in WrappedDeeplabcustommodel. Is this what you are seeing as well?

@ptrblck Yes, this is the error, and it finally ends as shown below. Can you please help us understand what the issue with WrappedDeeplabcustommodel is? Thanks


/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/torch/nn/modules/module.py(860): _slow_forward
/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/torch/nn/modules/module.py(887): _call_impl
convert.py(48): forward
/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/torch/nn/modules/module.py(860): _slow_forward
/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/torch/nn/modules/module.py(887): _call_impl
/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/torch/jit/_trace.py(934): trace_module
/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/torch/jit/_trace.py(733): trace
convert.py(54):
Comparison exception: expected tensor shape torch.Size([3, 4]) doesn’t match with actual tensor shape torch.Size([1])!

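My guess would be that this custom module uses data-dependent control flow and thus fails when trying to trace it. Tracing records only the operations executed for the example input, so anything that depends on the data (such as the number of detected instances) is frozen into the graph. Could you try to torch.jit.script it instead?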

Yes, we tried torch.jit.script and plugged in the lines of code below.

traceable_m = WrappedDeeplabcustommodel()
scripted_model = torch.jit.script(traceable_m)

# Fixed-shape input description handed to the Core ML converter.
example_input_type = ct.TensorType(shape=(1, 3, 512, 512))
mlmodel = ct.converters.convert(scripted_model, inputs=[example_input_type])

It gave the error below:

Traceback (most recent call last):
  File "bkp.py", line 58, in <module>
    mlmodel = ct.converters.convert(
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/_converters_entry.py", line 352, in convert
    mlmodel = mil_convert(
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 183, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 210, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 273, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 105, in __call__
    return load(*args, **kwargs)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 46, in load
    converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 156, in __init__
    raw_graph, params_dict = self._expand_and_optimize_ir(self.torchscript)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 456, in _expand_and_optimize_ir
    graph, params_dict = TorchConverter._jit_pass_lower_graph(graph, torchscript)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 400, in _jit_pass_lower_graph
    _lower_graph_block(graph)
  File "/home/jakep/anaconda2/envs/grocery/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 379, in _lower_graph_block
    module = getattr(node_to_module_map[_input], attr_name)
KeyError: images.4 defined in (%images.4 : __torch__.torchvision.models.detection.image_list.ImageList, %targets.14 : Dict(str, Tensor)[]? = prim::TupleUnpack(%490)

Please suggest steps to convert MaskRCNN to the .mlmodel format.

@ptrblck Please provide us an update to proceed further.

Sorry, I’m not familiar enough with coreml or the mlmodel format to be of any help, so you would have to wait for some CoreML experts.