Running ResNet-18 inference on the CPU of my MacBook Pro (macOS Mojave 10.14.5, Python 3.7.3) takes either about 0.2 seconds or close to 1.5 seconds, depending on which PyTorch version I have installed.
With version 1.1.0, inference takes about 0.2 seconds.
With any newer version, it takes over 1 second.
I created a new pyenv virtualenv with Python 3.7.3 and installed the latest PyTorch and torchvision with: python -m pip install torch torchvision
The script then reports the following per-inference times (in seconds):
1.445391
1.46278
1.43807
1.437759
1.442724
1.437736
1.438425
1.449995
1.512187
1.452032
AVG: 1.4517098999999998
I then created another virtualenv and installed an older PyTorch with: python -m pip install torch==1.1.0 torchvision==0.4.0
With this version, the script reports the following per-inference times (in seconds):
0.216007
0.218446
0.251732
0.210863
0.209307
0.220819
0.208335
0.210164
0.211357
0.207779
AVG: 0.2164809
My script:
import torch
from torchvision import transforms
from torchvision.models import resnet18
import requests
from io import BytesIO
from PIL import Image
from datetime import datetime
# Image preprocessing pipeline applied before inference:
# resize to 224, center-crop a 224x224 region, convert the PIL image to a
# tensor, then normalize each channel with the mean/std listed below
# (presumably the ImageNet statistics resnet18 was trained with -- confirm).
processor = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def ensure_compatible_sd(sd):
from collections import OrderedDict
new_state_dict = OrderedDict()
return_new_sd = False
for key, val in sd.items():
if key[:4] == "net.":
return_new_sd = True
new_state_dict[key[4:]] = val
return new_state_dict
def test_model(model, img_in):
all_time = []
model.eval()
with torch.no_grad():
for i in range(10):
start = datetime.now()
out = model(img_in)[0]
secs = (datetime.now()-start).total_seconds()
all_time.append(secs)
print(secs)
print(f"AVG: {sum(all_time)/len(all_time)}")
sdict_og = torch.load("best_checkpoint.pt", map_location='cpu')
sdict = ensure_compatible_sd(sdict_og)
reg_res = resnet18()
reg_res.fc = torch.nn.Linear(512, 17)
reg_res.load_state_dict(sdict)
image_url = 'https://www.opposuits.com/media/catalog/product/cache/16/image/550x/925f46717e92fbc24a8e2d03b22927e1/o/s/osui-0045_retro_suit_pac_man_1_v2.jpg'
img_content = requests.get(image_url).content
img = Image.open(BytesIO(img_content))
img_in = processor(img).unsqueeze(0)
test_model(reg_res, img_in)
Any idea what might be causing the slowdown in PyTorch versions newer than 1.1.0?