Model works fine, but the TorchScript (JIT) export gives wrong results

Hello everyone, I’ve trained a bird identification model and want to deploy it in mobile apps, so I converted it to TorchScript. However, the TorchScript model does not produce the same predictions as the original model.

Here is the model: model20240824.pth is the weight dict (state dict), and birdinfo.json is the label map.

Here is the code I used to export the model to TorchScript (my result):

import torch
from torch import nn
from torchvision.models import resnet34

from torch.utils.mobile_optimizer import optimize_for_mobile

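# Export on the CPU: the mobile lite interpreter runs on the CPU, and a module
# traced on CUDA would keep its parameters on CUDA.
device = torch.device("cpu")

model = resnet34()
model.fc = nn.Linear(model.fc.in_features, 11000)
# map_location lets weights saved from a GPU run load on a CPU-only install.
model.load_state_dict(torch.load('model20240824.pth', map_location=device))
model = model.to(device)

model.eval()

# The example input must be on the same device as the model.
example = torch.rand(1, 3, 224, 224, device=device)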

traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("model20240824.pt")

And here is my prediction code:

import json

import torch
from torch import nn

from PIL import Image
from torchvision.models import resnet34
from torchvision.transforms import transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
image_path = "test_images/86810585ED8E85C2CE8525BB8E17CF07.jpg"

with open('birdinfo.json', 'r') as f:
    data = f.read()

bird_info = json.loads(data)


def image_proprecess(img_path):
    img = Image.open(img_path).convert('RGB')
    data = data_transforms(img)
    data = torch.unsqueeze(data, 0)
    img_resize = img.resize((384, 384))
    return img_resize, data


data_transforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

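# Set USE_JIT = True to load the exported TorchScript model instead of the eager one.
USE_JIT = False

if USE_JIT:
    model = torch.jit.load('model20240824.pt', map_location=device)
else:
    model = resnet34()
    model.fc = nn.Linear(model.fc.in_features, 11000)
    model.load_state_dict(torch.load('model20240824.pth', map_location=device))

model = model.to(device).eval()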

img, data = image_proprecess(image_path)
data = data.to(device)

outputs = model(data)


probs_full = torch.nn.functional.softmax(outputs, dim=1)
probs, indices = torch.topk(probs_full, k=5, dim=1)
probs = probs.squeeze().tolist()

indices = indices.squeeze().tolist()
indices_zh = [bird_info[indice][0] for indice in indices]
indices_en = [bird_info[indice][1] for indice in indices]
indices_scientific_name = [bird_info[indice][2] for indice in indices]
probs_round = [f'{round(prob * 100, 3)}%' for prob in probs]

print(indices)
# print(indices_zh)
print(indices_en)
print(indices_scientific_name)
print(probs_round)
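Instead of toggling comments, both models could be loaded in the same prediction script and fed the exact same preprocessed tensor, which rules out preprocessing as the source of the difference. A minimal sketch; topk_agreement is a hypothetical helper, not part of PyTorch.

```python
import torch


def topk_agreement(model_a, model_b, data, k=5):
    """Feed the same preprocessed batch to two models and compare their top-k class indices."""
    with torch.no_grad():
        # softmax mirrors the prediction script; it does not change the ranking.
        top_a = torch.topk(torch.softmax(model_a(data), dim=1), k=k, dim=1).indices
        top_b = torch.topk(torch.softmax(model_b(data), dim=1), k=k, dim=1).indices
    # Drop the batch dimension (batch size 1) and convert to plain Python lists.
    top_a = top_a.squeeze(0).tolist()
    top_b = top_b.squeeze(0).tolist()
    return top_a, top_b, top_a == top_b
```

For example, top_eager, top_jit, same = topk_agreement(eager_model, jit_model, data), where eager_model (hypothetical name) is the resnet34 loaded from model20240824.pth, jit_model (hypothetical name) comes from torch.jit.load('model20240824.pt'), and data is the tensor returned by image_proprecess.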

Test image:

Results from the original model (the top result, Oriental Magpie-Robin (Copsychus saularis), is correct):

[8835, 6214, 8837, 9096, 8834]
['Oriental Magpie-Robin', 'Tropical Boubou', 'Madagascan Magpie-Robin', 'Pied Bush Chat', 'Indian Robin']
['Copsychus saularis', 'Laniarius major', 'Copsychus albospecularis', 'Saxicola caprata', 'Copsychus fulicatus']
['85.64%', '1.469%', '1.035%', '0.928%', '0.647%']

Results from the TorchScript model:

[1190, 10226, 8901, 2395, 2498]
['Eastern Buzzard', 'Altamira Oriole', 'Sunda Blue Flycatcher', 'Siau Scops Owl', 'Northern Hawk-Owl']
['Buteo japonicus', 'Icterus gularis', 'Cyornis caerulatus', 'Otus siaoensis', 'Surnia ulula']
['0.029%', '0.02%', '0.02%', '0.02%', '0.019%']

I am really confused. Could anyone please tell me what I did wrong? Any help is greatly appreciated.

torch: 2.1.1 (cpu)
torchvision: 0.16.1
Python 3.8.18

System:

OS: Arch Linux x86_64
Host: 83AL XiaoXinPro 14 IRH8
Kernel: 6.12.4-zen1-1-zen
Shell: bash 5.2.37
DE: Plasma 6.2.4
CPU: 13th Gen Intel i5-13500H (16) @ 4.700GHz
GPU: Intel Raptor Lake-P [Iris Xe Graphics]
Memory: 19090MiB / 31816MiB

Could you compare the outputs of the original and exported models in the same script where the export happens?
If they match, check the data preprocessing pipeline in your inference script next, as the difference might come from that part of your code.
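A minimal sketch of that check, meant to run inside the export script. check_export and max_abs_diff are hypothetical helper names, and the 1e-4 tolerance is an arbitrary choice. It traces the model, runs optimize_for_mobile, and reports how far each stage drifts from the eager model, so a mismatch introduced by tracing can be told apart from one introduced by the mobile optimization pass.

```python
import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile


def max_abs_diff(reference, candidate, example):
    """Largest absolute difference between two models' outputs on the same input."""
    with torch.no_grad():
        return (reference(example) - candidate(example)).abs().max().item()


def check_export(model, example, atol=1e-4):
    """Compare an eager model against its traced and mobile-optimized versions.

    Returns the max absolute output difference for each export stage.
    """
    model = model.eval()
    traced = torch.jit.trace(model, example)
    optimized = optimize_for_mobile(traced)
    diffs = {
        "traced": max_abs_diff(model, traced, example),
        "optimized": max_abs_diff(model, optimized, example),
    }
    for stage, diff in diffs.items():
        status = "ok" if diff <= atol else "MISMATCH"
        print(f"{stage}: max abs diff = {diff:.3e} ({status})")
    return diffs


if __name__ == "__main__":
    # Stand-in model; in the export script, pass the real resnet34 instead,
    # e.g. check_export(model, torch.rand(1, 3, 224, 224)).
    torch.manual_seed(0)
    tiny = nn.Sequential(
        nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5),
    )
    check_export(tiny, torch.rand(1, 3, 16, 16))
```

If "traced" is close to zero but "optimized" is not, the problem lies in the optimization step rather than in tracing or preprocessing.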

I deleted the line optimized_traced_model = optimize_for_mobile(traced_script_module), saved the traced module without mobile optimization, and that fixed it. It seems this issue was already reported in 2022.
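The fix described above can be sketched as an export helper that skips optimize_for_mobile and then reloads the saved file to confirm it reproduces the eager outputs. export_and_verify is a hypothetical name and the tolerance is an assumption; the reload uses torch.jit.load, mirroring the loading code in the prediction script.

```python
import os
import tempfile

import torch
from torch import nn


def export_and_verify(model, example, path, atol=1e-4):
    """Trace `model`, save it for the lite interpreter WITHOUT optimize_for_mobile,
    reload the saved file, and check that it reproduces the eager outputs."""
    model = model.eval()
    traced = torch.jit.trace(model, example)
    # Skipping optimize_for_mobile is the workaround from this thread.
    traced._save_for_lite_interpreter(path)
    reloaded = torch.jit.load(path)
    with torch.no_grad():
        expected = model(example)
        actual = reloaded(example)
    return torch.allclose(expected, actual, atol=atol)


if __name__ == "__main__":
    # Stand-in model; for the real export, call
    # export_and_verify(model, torch.rand(1, 3, 224, 224), "model20240824.pt").
    torch.manual_seed(0)
    tiny = nn.Sequential(
        nn.Conv2d(3, 4, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        nn.Flatten(), nn.Linear(4, 3),
    )
    with tempfile.TemporaryDirectory() as tmp:
        ok = export_and_verify(tiny, torch.rand(1, 3, 8, 8), os.path.join(tmp, "tiny.pt"))
        print("round trip matches:", ok)
```

Checking the reloaded file, rather than only the in-memory traced module, catches problems introduced by saving as well as by tracing.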