TorchServe handler does not load state_dict in Docker

I have a saved model that I am trying to serve with Docker. I can load the state dict both locally and inside Docker if I simply call load_state_dict() on the saved .pt file directly.

However, loading it through the handler doesn't seem to be working.

Here is how things are set up.

handler.py

This is the handler that torch-model-archiver bundles so TorchServe can load the weights.

import sys
import importlib
from pathlib import Path
sys.path.append("/home/model-server")

import torch
from ts.torch_handler.base_handler import BaseHandler
from ts.utils.util import list_classes_from_module
import torch.nn.functional as FN

print('\n-----Starting handler -----\n')

class DefaultClassifier(BaseHandler):

    def _load_pickeled_model(self, model_dir, model_file, model_pt_path):
        """Load a pickeled model"""

        print(f"Model dir is \n-----> {model_dir} ---------\n")
        print(f"Model file is \n-----> {model_file} ---------\n")


        model_dir = Path(model_dir)
        model_file_path = model_dir / model_file

        if not model_file_path.exists():
            raise RuntimeError(f"Model file {model_file_path} not found")

        module = importlib.import_module(model_file.split(".")[0])
        class_defs_in_module = list_classes_from_module(module)
        if len(class_defs_in_module) != 1:
            raise ValueError(f"Expected one model class, got {class_defs_in_module}")

        model_class = class_defs_in_module[0]
        model = model_class()

        if model_pt_path:
            print(f"Model pt path is \n-----> {model_pt_path} ---------\n")
            model.load_state_dict(torch.load(model_pt_path)['model'], strict=False)

        return model
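For context, the `torch.load(model_pt_path)['model']` indexing above is there because the checkpoint wraps the state_dict in a dict under a `'model'` key. A minimal sketch of that save/load round trip (`LitPonzi` here is a hypothetical stand-in with the layer names from the error log, not my real model code):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real LitPonzi in src/model.py;
# layer names mirror the ones in the error log (layer_1, layer_3).
class LitPonzi(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(4, 8)
        self.layer_3 = nn.Linear(8, 2)

# Saving the weights wrapped in a dict is what forces the
# torch.load(...)['model'] indexing in the handler above.
model = LitPonzi()
torch.save({"model": model.state_dict()}, "trained-model.pt")

# Loading back mirrors the handler's logic.
state = torch.load("trained-model.pt")["model"]
restored = LitPonzi()
restored.load_state_dict(state, strict=False)
```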

Dockerfile

The container in which the model is served.

FROM pytorch/torchserve:0.4.2-cpu

USER root
RUN printf "\nservice_envelope=json" >> /home/model-server/config.properties \
    && pip install pytorch-lightning torchmetrics google-cloud-storage scikit-learn
USER model-server


COPY src /home/model-server/src
COPY load_model.py /home/model-server/src
COPY configs /home/model-server/configs
COPY /models/trained-model-2021-11-23--23:08-hs.pt /home/model-server/src/

RUN torch-model-archiver \
    --model-name=ponzi \
    --version=0.1 \
    --model-file=/home/model-server/src/model.py \
    --serialized-file=/home/model-server/src/trained-model-2021-11-23--23:08-hs.pt \
    --handler=/home/model-server/src/handler.py \
    --export-path=/home/model-server/model-store \
    --extra-files=/home/model-server/configs/features.yml,/home/model-server/configs/model_configs.yml,/home/model-server/src/utils.py

CMD ["torchserve", \
     "--start", \
     "--ts-config=/home/model-server/config.properties", \
     "--models", \
     "ponzi=ponzi.mar"]

For some reason the handler starts, but judging from the logs it never actually calls _load_pickeled_model:

2021-11-24 10:28:14,156 [INFO ] W-9008-ponzi_v20211122-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9008.
2021-11-24 10:28:14,169 [INFO ] W-9006-ponzi_v20211122-stdout MODEL_LOG - model_name: ponzi, batchSize: 1
2021-11-24 10:28:14,170 [INFO ] W-9008-ponzi_v20211122-stdout MODEL_LOG - model_name: ponzi, batchSize: 1
2021-11-24 10:28:14,171 [INFO ] W-9006-ponzi_v20211122-stdout MODEL_LOG - 
2021-11-24 10:28:14,171 [INFO ] W-9006-ponzi_v20211122-stdout MODEL_LOG - -----Starting handler -----
2021-11-24 10:28:14,172 [INFO ] W-9008-ponzi_v20211122-stdout MODEL_LOG - 
2021-11-24 10:28:14,172 [INFO ] W-9008-ponzi_v20211122-stdout MODEL_LOG - -----Starting handler -----
2021-11-24 10:28:14,203 [INFO ] W-9009-ponzi_v20211122-stdout MODEL_LOG - model_name: ponzi, batchSize: 1
2021-11-24 10:28:14,204 [INFO ] W-9009-ponzi_v20211122-stdout MODEL_LOG - 
2021-11-24 10:28:14,204 [INFO ] W-9009-ponzi_v20211122-stdout MODEL_LOG - -----Starting handler -----
2021-11-24 10:28:15,323 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG - 
2021-11-24 10:28:15,323 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG - Backend worker process died.
2021-11-24 10:28:15,323 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG - Traceback (most recent call last):
2021-11-24 10:28:15,323 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 183, in <module>
2021-11-24 10:28:15,323 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     worker.run_server()
2021-11-24 10:28:15,323 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 155, in run_server
2021-11-24 10:28:15,324 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     self.handle_connection(cl_socket)
2021-11-24 10:28:15,324 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 117, in handle_connection
2021-11-24 10:28:15,324 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     service, result, code = self.load_model(msg)
2021-11-24 10:28:15,324 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 90, in load_model
2021-11-24 10:28:15,324 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     service = model_loader.load(model_name, model_dir, handler, gpu, batch_size, envelope)
2021-11-24 10:28:15,325 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/model_loader.py", line 110, in load
2021-11-24 10:28:15,325 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     initialize_fn(service.context)
2021-11-24 10:28:15,325 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/torch_handler/base_handler.py", line 66, in initialize
2021-11-24 10:28:15,325 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
2021-11-24 10:28:15,325 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/ts/torch_handler/base_handler.py", line 130, in _load_pickled_model
2021-11-24 10:28:15,325 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     model.load_state_dict(state_dict)
2021-11-24 10:28:15,326 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -   File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
2021-11-24 10:28:15,326 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG -     self.__class__.__name__, "\n\t".join(error_msgs)))
2021-11-24 10:28:15,326 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG - RuntimeError: Error(s) in loading state_dict for LitPonzi:
2021-11-24 10:28:15,326 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG - 	Missing key(s) in state_dict: "layer_1.weight", "layer_1.bias", "layer_3.weight", "layer_3.bias". 
2021-11-24 10:28:15,326 [INFO ] W-9010-ponzi_v20211122-stdout MODEL_LOG - 	Unexpected key(s) in state_dict: "model". 
2021-11-24 10:28:15,327 [INFO ] epollEventLoopGroup-5-13 org.pytorch.serve.wlm.WorkerThread - 9010 Worker disconnected. WORKER_STARTED
2021-11-24 10:28:15,328 [DEBUG] W-9010-ponzi_v20211122 org.pytorch.serve.wlm.WorkerThread - System state is : WORKER_STARTED
2021-11-24 10:28:15,328 [DEBUG] W-9010-ponzi_v20211122 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.

I have tested loading the state_dict standalone inside the container, and that works. For example, the following runs fine:

import torch
from src.model import LitPonzi

# test to be run in docker
loaded_model = LitPonzi()
loaded_model.load_state_dict(torch.load('src/trained-model-2021-11-23--23:08-hs.pt')['model'], strict=False)
loaded_model.eval()
print(f"Loaded model, {loaded_model}")

Any idea what I am doing wrong here? I appreciate your help.

I think the fix is to export the model with TorchScript and load it with torch.jit.load(), so you stop juggling wrapped state_dicts in the handler entirely.
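A minimal sketch of that approach (the `LitPonzi` class here is a placeholder for your real model in src/model.py; layer names and sizes are assumptions):

```python
import torch
import torch.nn as nn

# Placeholder for the real LitPonzi; shapes are illustrative only.
class LitPonzi(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(4, 8)
        self.layer_3 = nn.Linear(8, 2)

    def forward(self, x):
        return self.layer_3(torch.relu(self.layer_1(x)))

# Export once, after training: TorchScript bundles code + weights
# into a single file, so no model class is needed at serve time.
model = LitPonzi()
model.eval()
scripted = torch.jit.script(model)
torch.jit.save(scripted, "ponzi_scripted.pt")

# At serve time the file loads without importing the model class.
loaded = torch.jit.load("ponzi_scripted.pt")
loaded.eval()
```

If I remember correctly, when you pass a TorchScript file as --serialized-file and omit --model-file from torch-model-archiver, the BaseHandler falls back to torch.jit.load on its own, so you may not need a custom loading override at all.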