torch.utils.mobile_optimizer.optimize_for_mobile produces different output than the torch model and JIT model

When I convert my torch model with the mobile optimizer, it gives different outputs. However, the outputs of the torch model and the JIT-traced model are identical.

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

# Model definition with two classes
model = torchvision.models.mobilenet_v3_large()
model.classifier[3] = torch.nn.Linear(in_features=1280, out_features=2)
model.load_state_dict(torch.load('model.ptl'))
model.eval()

# Trace the model with an example input
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)

# Optimize the traced model for mobile
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
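
For reference, the optimized module would then typically be exported for the mobile (lite interpreter) runtime like this; the filename here is just an example:

# Export the optimized module for the PyTorch mobile runtime
traced_script_module_optimized._save_for_lite_interpreter("model_mobile.ptl")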

1. INFERENCE TORCH MODEL
model.forward(torch.ones([1, 3, 224, 224], dtype=torch.float))

# Output
tensor([[-3.0561,  3.0894]], grad_fn=<AddmmBackward0>)
2. INFERENCE JIT MODEL
traced_script_module.forward(torch.ones([1, 3, 224, 224], dtype=torch.float))

# Output
tensor([[-3.0561,  3.0894]], grad_fn=<AddmmBackward0>)
3. INFERENCE OPTIMIZED MODEL
traced_script_module_optimized.forward(torch.ones([1, 3, 224, 224], dtype=torch.float))

# Output
tensor([[-4.5466,  6.5033]])
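
To compare the outputs programmatically rather than by eye, a small check along these lines can be used (reusing the objects from the snippet above; the tolerance is an arbitrary choice):

x = torch.ones([1, 3, 224, 224], dtype=torch.float)
with torch.no_grad():
    out_eager = model(x)
    out_traced = traced_script_module(x)
    out_mobile = traced_script_module_optimized(x)

# Eager and traced outputs match; the mobile-optimized output does not
print(torch.allclose(out_eager, out_traced, atol=1e-6))  # True
print(torch.allclose(out_eager, out_mobile, atol=1e-6))  # False here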

I tested with several examples and got the same result: after optimizing for mobile, the model's output is different.

I have seen @ptrblck's responses to many errors and problems related to PyTorch. Could you please look at this problem? Why does the trained model give different outputs after optimize_for_mobile is applied?

I cannot reproduce the difference in the model outputs using the current nightly release:

'1.12.0.dev20220501+cu116'

and get:

tensor([[-0.0082, -0.0152]], grad_fn=<AddmmBackward0>)
tensor([[-0.0082, -0.0152]], grad_fn=<AddmmBackward0>)
tensor([[-0.0082, -0.0152]])

However, I would also recommend calling the models directly (not their forward method) via:

out1 = model(torch.ones([1, 3, 224, 224], dtype=torch.float))
out2 = traced_script_module(torch.ones([1, 3, 224, 224], dtype=torch.float))
out3 = traced_script_module_optimized(torch.ones([1, 3, 224, 224], dtype=torch.float))

since calling forward directly would skip forward hooks etc.
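As a small illustration of the difference (a toy module with a print hook, unrelated to your model):

import torch
import torch.nn as nn

m = nn.Linear(2, 2)
m.register_forward_hook(lambda mod, inp, out: print("hook called"))

x = torch.randn(1, 2)
_ = m(x)          # goes through __call__ and triggers the hook
_ = m.forward(x)  # bypasses __call__, so the hook is skipped
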
Could you use the latter approach and also check the current nightly release to see if you are still hitting this issue, please?

PS: I’m using the randomly initialized models as I don’t have access to your state_dict.


Thank you very much @ptrblck for your quick response. I tried the nightly release and called the models directly without the forward method:

'1.12.0.dev20220502+cu116'

and get:

tensor([[-0.1445,  0.1022]], grad_fn=<AddmmBackward0>)
tensor([[-0.1445,  0.1022]], grad_fn=<AddmmBackward0>)
tensor([[-184.6089,  197.5092]])

As the two PyTorch versions gave different outputs with optimize_for_mobile, I also tried an older version:

'1.7.1+cu110'

and get:

tensor([[-0.1445,  0.1022]], grad_fn=<AddmmBackward>)
tensor([[-0.1445,  0.1022]], grad_fn=<AddBackward0>)
tensor([[-0.1445,  0.1022]], grad_fn=<AddBackward0>)

It is strange, but the older torch version works well and produces the same results.

Are you also seeing the different results if you are using a randomly initialized model or only if the state_dict is loaded? It’s concerning that the current nightly creates such a high mismatch.

@ptrblck Even with randomly initialized models, the results are different. Here is the result for ResNet50. I just downloaded the pretrained model and ran inference on it:

With:

'1.12.0.dev20220502+cu116'
model = torchvision.models.resnet50(pretrained=True)
model.eval()

Test with the same input tensor (tracing and optimization were done as in the first post):

out1 = model(torch.ones([1, 3, 224, 224], dtype=torch.float))
out2 = traced_script_module(torch.ones([1, 3, 224, 224], dtype=torch.float))
out3 = traced_script_module_optimized(torch.ones([1, 3, 224, 224], dtype=torch.float))

and results:

tensor([[-3.0806e-01,  7.9845e-02, -1.1900e+00, -1.4837e+00, -5.1359e-01,...]])
tensor([[-3.0806e-01,  7.9845e-02, -1.1900e+00, -1.4837e+00, -5.1359e-01,...]])
tensor([[-5.2194e+00,  1.4939e+00, -2.2703e+00, -5.4212e+00, -2.0551e+00,...]])

The pretrained ResNet50 has 1000 class labels, so I only included the first five values.


Also, I tested with torch version '1.10.1+cu113', and it works just fine.

tensor([[-3.0806e-01,  7.9845e-02, -1.1900e+00, -1.4837e+00, -5.1359e-01,...]])
tensor([[-3.0806e-01,  7.9845e-02, -1.1900e+00, -1.4837e+00, -5.1359e-01,...]])
tensor([[-3.0806e-01,  7.9845e-02, -1.1900e+00, -1.4837e+00, -5.1359e-01,...]])

In my opinion, the problem started with the stable torch version 1.11.0 and persists in the subsequent nightly versions.
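
For the record, here is a condensed, self-contained script that reproduces the mismatch for me on the affected versions (ResNet50 and the tolerance are just example choices):

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

print(torch.__version__)

model = torchvision.models.resnet50(pretrained=True)
model.eval()

# Trace and optimize for mobile
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
optimized = optimize_for_mobile(traced)

# Compare the three outputs on the same input
x = torch.ones([1, 3, 224, 224], dtype=torch.float)
with torch.no_grad():
    out1 = model(x)
    out2 = traced(x)
    out3 = optimized(x)

print(torch.allclose(out1, out2, atol=1e-6))  # True
print(torch.allclose(out1, out3, atol=1e-6))  # False on 1.11+ in my runs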

It indeed looks like a bug. Could you create an issue on GitHub with a link to this topic and the minimal code snippet to reproduce the error, please?

@ptrblck Here is the issue I opened on GitHub.

Is it a quantized model?

No, it's just a regular (non-quantized) PyTorch model.