Hello everyone, I’ve trained a bird identification model and want to deploy it in mobile apps, so I converted it to TorchScript. However, the JIT model does not produce the same results as the original model.
Here is the model: `model20240824.pth` is the weight dict, and `birdinfo.json` is the label map.
Here is the code I used to export the JIT model (my result):
import torch
from torch import nn
from torchvision.models import resnet34
from torch.utils.mobile_optimizer import optimize_for_mobile
# Export the fine-tuned ResNet-34 classifier as a mobile-optimized TorchScript
# file for the lite interpreter.
#
# Everything is kept on the CPU: the example input used for tracing must live
# on the same device as the model (tracing a CUDA model with a CPU tensor fails
# with a device mismatch), and the exported artifact is meant for mobile.
device = torch.device("cpu")

# Rebuild the architecture with the 11000-way classification head, then load
# the trained weights. map_location lets a checkpoint that was saved from a GPU
# be loaded on a machine without CUDA.
model = resnet34()
model.fc = nn.Linear(model.fc.in_features, 11000)
model.load_state_dict(torch.load('model20240824.pth', map_location=device))
model = model.to(device)
# eval() must be set before tracing so batch-norm uses its running statistics.
model.eval()

# Tracing records the operations executed for this example input.
example = torch.rand(1, 3, 224, 224, device=device)
with torch.no_grad():
    traced_script_module = torch.jit.trace(model, example)

optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("model20240824.pt")
And here is my code for prediction:
import json
import torch
from torch import nn
from PIL import Image
from torchvision.models import resnet34
from torchvision.transforms import transforms
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Image to classify.
image_path = "test_images/86810585ED8E85C2CE8525BB8E17CF07.jpg"

# Label map, indexed by class id; each entry is read below at positions 0, 1
# and 2. The encoding is given explicitly so that non-ASCII names decode the
# same way on every platform instead of depending on the locale default.
with open('birdinfo.json', 'r', encoding='utf-8') as f:
    bird_info = json.load(f)
def image_proprecess(img_path):
    """Load an image and prepare it for the classifier.

    Args:
        img_path: Path of the image file to open.

    Returns:
        A tuple ``(preview, batch)`` where ``preview`` is the RGB image
        resized to 384x384 and ``batch`` is the output of the module-level
        ``data_transforms`` pipeline with a batch dimension added at index 0.
    """
    rgb_image = Image.open(img_path).convert('RGB')
    # The model input is derived from the full-size RGB image, not the preview.
    batch = data_transforms(rgb_image).unsqueeze(0)
    preview = rgb_image.resize((384, 384))
    return preview, batch
# Preprocessing applied to every input image: resize the shorter side to 224,
# take the central 224x224 crop, convert to a tensor, then normalize each
# channel. NOTE(review): the mean/std values match the commonly used ImageNet
# statistics; confirm they are the ones the model was trained with.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Select which model to evaluate: the exported TorchScript file (True) or the
# original eager model rebuilt from the weight dict (False). An explicit flag
# guarantees `model` is always defined, unlike toggling commented-out lines.
USE_JIT = False

if USE_JIT:
    # NOTE(review): the exported file went through optimize_for_mobile, which
    # by default targets the CPU backend, so it is loaded and run on the CPU
    # here - verify before moving it to another device.
    device = torch.device("cpu")
    model = torch.jit.load('model20240824.pt', map_location=device)
else:
    model = resnet34()
    model.fc = nn.Linear(model.fc.in_features, 11000)
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('model20240824.pth', map_location=device))
model = model.to(device).eval()

img, data = image_proprecess(image_path)
data = data.to(device)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(data)
    probs_full = torch.nn.functional.softmax(outputs, dim=1)
    probs, indices = torch.topk(probs_full, k=5, dim=1)

# Drop the batch dimension and convert to plain Python lists.
probs = probs.squeeze().tolist()
indices = indices.squeeze().tolist()

# Look up the three name fields stored for each predicted class id.
indices_zh = [bird_info[idx][0] for idx in indices]
indices_en = [bird_info[idx][1] for idx in indices]
indices_scientific_name = [bird_info[idx][2] for idx in indices]
probs_round = [f'{round(prob * 100, 3)}%' for prob in probs]

print(indices)
# print(indices_zh)
print(indices_en)
print(indices_scientific_name)
print(probs_round)
test image:
Results given by the original model; the top result, Oriental Magpie-Robin (Copsychus saularis), is correct:
[8835, 6214, 8837, 9096, 8834]
['Oriental Magpie-Robin', 'Tropical Boubou', 'Madagascan Magpie-Robin', 'Pied Bush Chat', 'Indian Robin']
['Copsychus saularis', 'Laniarius major', 'Copsychus albospecularis', 'Saxicola caprata', 'Copsychus fulicatus']
['85.64%', '1.469%', '1.035%', '0.928%', '0.647%']
Results given by the JIT model:
[1190, 10226, 8901, 2395, 2498]
['Eastern Buzzard', 'Altamira Oriole', 'Sunda Blue Flycatcher', 'Siau Scops Owl', 'Northern Hawk-Owl']
['Buteo japonicus', 'Icterus gularis', 'Cyornis caerulatus', 'Otus siaoensis', 'Surnia ulula']
['0.029%', '0.02%', '0.02%', '0.02%', '0.019%']
I am really confused. Could anyone please tell me what I did wrong? Any help is greatly appreciated.