We trained a GPT-2 model using the simpletransformers library and served it with Flask on an EC2 instance, with good results. From there, we decided to move to TorchServe to improve scalability and reliability when serving the model, but the quality of the predictions is not as good as we expected, or at least they are not the same as the predictions we get when serving with Flask. Has anyone run into the same problem?
We load the model the same way in both setups, and the model files and dependencies are identical.
In Flask we use something similar to this:
self.model = ConvAIModel("gpt2", self.model_directory, use_cuda=False)
response, new_history = self.model.interact_single(
    message, history, personality=personality
)
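For context, the Flask route is essentially a thin wrapper around `interact_single`. A minimal sketch of its shape (the `StubModel` class and the `chat` function are illustrative stand-ins, not our real code; the stub replaces the real `ConvAIModel` so the sketch is runnable):

```python
class StubModel:
    """Stand-in for ConvAIModel; the real code calls interact_single the same way."""

    def interact_single(self, message, history, personality=None):
        # The real model generates a reply; the stub just echoes the message.
        reply = "stub reply to: " + message
        return reply, history + [message, reply]


def chat(model, payload):
    """Mimics the body of the Flask view: unpack the JSON body, call the model."""
    response, new_history = model.interact_single(
        payload["text"], payload["history"], personality=payload.get("personality")
    )
    return {"response": response, "history": new_history}


result = chat(StubModel(), {"text": "hello", "history": []})
print(result["response"])  # the reply string returned by the model
```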
With TorchServe we created a custom handler using the simpletransformers library:
import logging

import torch
from simpletransformers.conv_ai import ConvAIModel
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class ConversationAIHandler(BaseHandler):
    '''
    Conversational AI handler class. This handler takes a text (string) and a
    chat history as input and returns the model's response based on the
    serialized transformers checkpoint.
    '''

    def __init__(self):
        super(ConversationAIHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Check if CUDA (GPU) is available; if it is, use it
        cuda_available = torch.cuda.is_available()

        # Read the serialized model file
        self.model = ConvAIModel("gpt2", model_dir, use_cuda=cuda_available)
        self.initialized = True

    def preprocess(self, data):
        logger.debug(data)
        return data[0].get("body")

    def inference(self, model_input):
        '''
        Predict the answer to a text using the trained transformer model.
        '''
        # personality must be a list of strings, not a string that
        # looks like a list
        personality = ["i like trains .", "i like playing videogames ."]
        response, history = self.model.interact_single(
            model_input["text"], model_input["history"], personality=personality
        )
        return {"response": response, "history": history}

    def postprocess(self, data):
        return [data]


_service = ConversationAIHandler()


def handle(data, context):
    if not _service.initialized:
        _service.initialize(context)
    if data is None:
        return None
    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)
    return data
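For completeness, this is roughly how we call the TorchServe endpoint. The handler's preprocess() reads the request body, so the client must POST a JSON object with "text" and "history" keys (the helper name, model name "conversation", and port below are illustrative, not our actual deployment):

```python
import json


def build_inference_payload(text, history):
    """Build the JSON body that the custom handler's preprocess() expects."""
    return json.dumps({"text": text, "history": history}).encode("utf-8")


payload = build_inference_payload("hello", [])

# Sending it (not executed here; assumes TorchServe is listening on
# localhost:8080 with a model registered as 'conversation'):
#
# import urllib.request
# req = urllib.request.Request(
#     "http://localhost:8080/predictions/conversation",
#     data=payload,
#     headers={"Content-Type": "application/json"},
# )
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))

print(json.loads(payload))
```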