How to use a custom method for prediction with PyTorch JIT in a custom handler for a production environment

I want to deploy a model (locally as a first step) using TorchServe.

I trained my model and saved it in eager mode.

Model architecture:

- Encoder
  A simple CNN built on the pretrained Inception v3.

- Decoder
  A language model (using an LSTM cell) that generates a caption for an image.

- Encoder_to_Decoder
  Hooks up the two preceding models; it contains two methods, the ordinary forward (for training) and caption_image (for inference); a simplified sketch follows.
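
For reference, the wrapper looks roughly like this (a simplified sketch reconstructed from the description above; the real layer definitions and the generation loop are omitted, and the class name Captioner is taken from the error log further down):

import torch
import torch.nn as nn


class Captioner(nn.Module):
    """Simplified sketch of the Encoder_to_Decoder wrapper."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder  # CNN built on the pretrained Inception v3
        self.decoder = decoder  # LSTM-based language model

    def forward(self, images: torch.Tensor, captions: torch.Tensor) -> torch.Tensor:
        # Training path: needs both the images and the ground-truth captions.
        features = self.encoder(images)
        return self.decoder(features, captions)

    def caption_image(self, image: torch.Tensor):
        # Inference path: generates a caption token by token from the image alone.
        ...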

WHAT I WANT: when deploying to production, I want to execute the caption_image method at inference time on the given tensorized image, instead of forward.
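
In other words, the handler's inference step should effectively do something like this (illustrative only):

model_output = self.model.caption_image(data)  # instead of self.model(data)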

Custom Handler (filename: custom_handler.py)

import json
import torch
import io
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler
from torchvision import transforms


class CaptionHandler(BaseHandler):
    """
    A custom model handler implementation.
    The handler takes an image, preprocesses it into a tensor,
    and feeds it to the network for inference.
    The generated caption is then post-processed
    and returned as the response.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        with open("index_to_name.json") as f:
            self.vocabulary = json.load(f)

    def preprocess(self, request):
        request = request[0]
        image = request.get("data")
        if image is None:
            image = request.get("body")

        image = Image.open(io.BytesIO(image))
        transformed_image = self.transform(image)

        return transformed_image

    def inference(self, data):
        model_output = self.model(data)  # gets stuck here!
        return model_output

    def postprocess(self, data):
        return data

    def handle(self, data, context):
        model_input = self.preprocess(data)
        # print("model input", model_input)
        model_output = self.inference(model_input)
        return self.postprocess(model_output)

And the entry point that uses that custom handler (filename: caption_handler.py):

from custom_handler import CaptionHandler

_service = CaptionHandler()


def handle(data, context):
    """
    Handle inference requests to TorchServe.
    :param data: the supplied request data
    :param context: the server context
    :return: the predicted caption
    """

    # if the handler is not yet initialized with the context
    if not _service.initialized:
        _service.initialize(context)

    # No data supplied
    if data is None:
        return None

    # preprocess, inference and postprocess our image
    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)

    return data

Then I packaged all my model artifacts using torch-model-archiver with this command:

torch-model-archiver --model-name caption --version 1.0 --serialized-file checkpoints/caption.pt --extra-files ./custom_handler.py,./index_to_name.json --handler caption_handler.py --export-path model_store --force
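
I then started TorchServe pointing at the export directory, along these lines (command reproduced from memory; exact flags may differ slightly):

torchserve --start --model-store model_store --models caption=caption.mar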

In the file logs/model_log.log I got:

MODEL_LOG - RuntimeError: forward() is missing value for argument 'captions'. Declaration: forward(__torch__.utils.models.Captioner self, Tensor images, Tensor captions) -> (Tensor)

Indeed, the forward method of my model class requires both the images and captions parameters for training, but at inference time I want to use the caption_image method instead and supply only the image in my POST request:

curl http://127.0.0.1:8080/predictions/caption -T Data/Images/test/3637013_c675de7705.jpg

output:

{
  "code": 503,
  "type": "InternalServerException",
  "message": "Prediction failed"
}

Thank you,

Any help will be appreciated!

Could you try filling captions with a dummy value, and then, in your forward() method, running the inference path (ignoring the dummy) whenever that value is detected?
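
A minimal sketch of that idea, applied to the Captioner sketched in the question (an empty tensor plays the role of the dummy value, so the check stays TorchScript-compatible):

def forward(self, images: torch.Tensor, captions: torch.Tensor) -> torch.Tensor:
    # An empty captions tensor is the agreed-upon dummy value:
    # when the handler passes it, run the inference path instead.
    if captions.numel() == 0:
        return self.caption_image(images)
    # Otherwise, the normal training path.
    features = self.encoder(images)
    return self.decoder(features, captions)

and in CaptionHandler.inference:

model_output = self.model(data, torch.empty(0, dtype=torch.long))

Since the archived caption.pt is TorchScript (the log shows a scripted forward declaration), the model has to be re-scripted and re-archived after adding this branch. Alternatively, decorating caption_image with @torch.jit.export before scripting keeps the method callable on the loaded module, so inference() could simply call self.model.caption_image(data) with no dummy argument at all.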