Below are code snippets from three files — convert.py, model.py, and task.py. Please review them.
To reproduce the issue: save any input image as input.jpg, then run convert.py.
model.py code:
import torch
import torch.nn as nn
from task import MaskRCNN
import torch.nn.functional as F
class InstanceSegmentation(nn.Module):
    """Wraps a Mask R-CNN task network behind a fixed-size resize step.

    Input batches are bilinearly resized to a square of side
    ``grid_size * factor`` before being handed to the detection network.
    """

    def __init__(self, grid_size=32, factor=3, delta=0.72, pretrained=True, num_classes=0):
        """
        Args:
            grid_size: base spatial grid; input is resized to grid_size * factor.
            factor: multiplier applied to grid_size for the resize target.
            delta: stored but not referenced in this module — presumably used
                by callers; TODO confirm.
            pretrained: forwarded to the task network constructor.
            num_classes: number of classes for the replaced prediction heads.
                NOTE(review): the default of 0 will break head construction
                downstream — callers should always pass a positive value.
        """
        super(InstanceSegmentation, self).__init__()
        self.grid_size = grid_size
        self.factor = factor
        self.delta = delta
        self.task_network = MaskRCNN(pretrained=pretrained, num_classes=num_classes)

    def forward(self, images, target=None):
        """Resize a batched image tensor and run the task network.

        Returns:
            The loss dict when ``target`` is given (training), otherwise the
            per-image list of detection dicts (inference).
        """
        side = self.grid_size * self.factor  # hoisted: loop/branch invariant
        images = F.interpolate(images,
                               (side, side),
                               mode='bilinear',
                               align_corners=True)
        # torchvision detection models expect a list of 3D tensors,
        # not a 4D batch; list(images) iterates the batch dimension.
        images = list(images)
        if target is not None:
            return self.task_network(images, target)
        return self.task_network(images)
convert.py code:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import torch
import torch.nn as nn
import torchvision
import json
from model import InstanceSegmentation
from torchvision import transforms
from PIL import Image
from collections import namedtuple
import coremltools as ct
# Build the instance-segmentation model; num_classes=23 replaces the
# COCO-pretrained box/mask heads with freshly initialised 23-class heads.
m = InstanceSegmentation(grid_size=32,
factor=3,
delta=0.72,
pretrained=True,
num_classes=23)
# Inference mode: the detection network returns detections, not losses.
# NOTE(review): this top-level `m` is never used below — the wrapper class
# constructs its own copy. Consider removing one of the two instantiations.
m = m.eval()
input_image = Image.open("input.jpg")
input_image.show()
# Standard ImageNet normalisation — assumed to match the backbone's
# pretraining statistics; TODO confirm this matches training of the new heads.
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
])
input_tensor = preprocess(input_image)
# Add a batch dimension: (C, H, W) -> (1, C, H, W).
input_batch = input_tensor.unsqueeze(0)
class WrappedDeeplabcustommodel(nn.Module):
    """Adapter module that exposes only the first image's mask tensor.

    torch.jit.trace works best on modules returning plain tensors, so this
    wrapper unpacks the detection output down to the 'masks' entry of the
    first (and only) image in the batch.
    """

    def __init__(self):
        super(WrappedDeeplabcustommodel, self).__init__()
        model = InstanceSegmentation(grid_size=32,
                                     factor=3,
                                     delta=0.72,
                                     pretrained=True,
                                     num_classes=23)
        self.m = model.eval()

    def forward(self, x):
        # Run the detector, then pull out the predicted masks for the
        # single image in the batch.
        predictions = self.m(x)
        return predictions[0]['masks']
traceable_m = WrappedDeeplabcustommodel().eval()
# NOTE(review): torch.jit.trace records a single execution path. Detection
# models whose output structure depends on the input (lists of dicts, a
# variable number of detections) are generally not traceable and need
# torch.jit.script instead — this call is the likely point of failure
# being reported. no_grad avoids recording autograd state in the trace.
with torch.no_grad():
trace = torch.jit.trace(traceable_m, input_batch)
task.py code:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def MaskRCNN(pretrained=True, num_classes=0):
    """Build a Mask R-CNN (ResNet-50 FPN) with replaced box and mask heads.

    Args:
        pretrained: load COCO-pretrained weights before the heads are
            replaced. BUG FIX: this parameter was previously ignored
            (the loader was hard-coded to ``pretrained=True``).
        num_classes: number of classes (including background) for the new
            box and mask predictors. Must be > 0 for the predictors to be
            constructible; the default of 0 is kept only for interface
            compatibility.

    Returns:
        A torchvision MaskRCNN module with freshly initialised heads.
    """
    # Load an instance-segmentation model, optionally pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=pretrained)
    # Replace the pre-trained classification head with one sized for num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # Replace the mask predictor; 1024 hidden channels (256 * 4) as before.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256 * 4
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)
    return model