I have a Detectron2 model. For the inference API we were using Flask, but we are now switching to KFServing. KFServing requires a saved model file and the class name in order to load the model.
Previously I used this code snippet to load the model:
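```python
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8  # set threshold for this model
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.01   # NMS IoU threshold
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1          # one foreground class

model = build_model(cfg)                              # build the architecture (untrained weights)
DetectionCheckpointer(model).load("model_final.pth")  # load the trained weights
model.train(False)                                    # inference mode, same as model.eval()
```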
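For reference, with `NUM_CLASSES = 1` the two `ROI_HEADS` test thresholds in this config act roughly as follows: detections whose score is not above `SCORE_THRESH_TEST` are dropped, then greedy non-maximum suppression removes any box whose IoU with a higher-scoring kept box exceeds `NMS_THRESH_TEST`. Below is a plain-Python sketch of that post-processing, not detectron2's implementation (which runs batched, per-class NMS on tensors); the function names `iou` and `filter_detections` are hypothetical.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(boxes: List[Box], scores: List[float],
                      score_thresh: float, nms_thresh: float) -> List[int]:
    """Return indices of boxes kept after score filtering and greedy NMS."""
    # Step 1: drop low-confidence detections (cf. SCORE_THRESH_TEST).
    candidates = [i for i, s in enumerate(scores) if s > score_thresh]
    # Step 2: greedy NMS, highest score first (cf. NMS_THRESH_TEST).
    candidates.sort(key=lambda i: scores[i], reverse=True)
    kept: List[int] = []
    for i in candidates:
        if all(iou(boxes[i], boxes[j]) <= nms_thresh for j in kept):
            kept.append(i)
    return kept


boxes = [(0, 0, 10, 10), (5, 0, 15, 10), (50, 50, 60, 60)]
scores = [0.95, 0.90, 0.85]
print(filter_detections(boxes, scores, score_thresh=0.8, nms_thresh=0.01))  # [0, 2]
print(filter_detections(boxes, scores, score_thresh=0.8, nms_thresh=0.5))   # [0, 1, 2]
```

The first two boxes overlap with an IoU of about 0.33, so with `NMS_THRESH_TEST = 0.01` even that modest overlap suppresses the lower-scoring one; touching or overlapping instances therefore tend to collapse into a single detection.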
But now I need something like this:
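```python
import torch
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()  # then define the layers
    def forward(self, x):
        ...  # compute and return the output
```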