PyTorch Mobile iOS: ResNet50 Not Computing

I am using torch.jit to trace a pretrained vanilla ResNet50, which I then import into iOS and call through PyTorch Mobile (C++).

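- (NSInteger)predictImage:(void*)imageBuffer forLabels:(NSInteger)labelCount {
    int outputLabelIndex = -1;
    try {
        std::cout << "\npredictImage";
        // Wrap the caller's buffer (no copy) as a 1x3x224x224 float tensor
        at::Tensor tensor = torch::from_blob(imageBuffer, {1, 3, 224, 224}, at::kFloat);
        std::cout << "\npredictImageTwo";
        // Inference only: turn off autograd bookkeeping
        torch::autograd::AutoGradMode guard(false);
        at::AutoNonVariableTypeMode non_var_type_mode(true);
        // Pass in image tensor to C++ scripted torch module
        auto outputTensor = _impl.forward({tensor}).toTensor();  // <- never returns for ResNet50
        std::cout << "\nReceived outputTensor";
        // ... (rest of the method: compute outputLabelIndex from outputTensor)
    } catch (const std::exception& exception) {
        NSLog(@"%s", exception.what());
    }
    return outputLabelIndex;
}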

The _impl.forward call just before I print “Received outputTensor” never returns, so that message never appears. This is in the predictImage method of my “TorchModule.mm” file in my iOS project.
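For reference, torch::from_blob does not copy or convert anything: it wraps the existing memory and reads it with the sizes and dtype it is given. Here that means imageBuffer must hold 1x3x224x224 contiguous float32 values in NCHW (channel-major) order, i.e. 150,528 floats. The Python sketch below illustrates that layout under stated assumptions: the source pixels arrive as row-major RGB uint8 tuples, and the ImageNet mean/std normalization that torchvision's pretrained models expect is applied on the app side. `nchw_offset` and `rgb_hwc_to_nchw` are hypothetical helpers, not part of any PyTorch API.

```python
from array import array

# Shape passed to torch::from_blob in the post: {1, 3, 224, 224}, kFloat (float32)
N, C, H, W = 1, 3, 224, 224

# Assumption: ImageNet normalization as used by torchvision's pretrained models.
# The traced model in this thread does not normalize its own input.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def nchw_offset(n: int, c: int, h: int, w: int) -> int:
    """Flat index of element (n, c, h, w) in a contiguous NCHW buffer."""
    return ((n * C + c) * H + h) * W + w


def rgb_hwc_to_nchw(pixels) -> array:
    """Convert H*W (r, g, b) uint8 tuples in row-major HWC order into a
    normalized float32 buffer laid out as 1x3x224x224 (NCHW)."""
    if len(pixels) != H * W:
        raise ValueError(f"expected {H * W} pixels, got {len(pixels)}")
    buf = array("f", [0.0]) * (N * C * H * W)
    for i, rgb in enumerate(pixels):
        h, w = divmod(i, W)  # row and column of this pixel
        for c in range(C):
            value = rgb[c] / 255.0  # scale to [0, 1]
            buf[nchw_offset(0, c, h, w)] = (value - MEAN[c]) / STD[c]
    return buf
```

A buffer with the wrong size or layout would not explain a hang, but it can produce garbage predictions or reads past the end of the allocation, so it is worth ruling out on its own.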

Basically, I’ve gotten ResNet18, ResNet34, and MobileNet working with PyTorch Mobile and the iOS demo, but I never receive the output tensor for ResNet50 and larger. Is ResNet50 supported? If anyone has gotten ResNet50 working on mobile, could they help?

Much appreciated!

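import os

import torch
import torch.nn as nn
import torchvision.models as models

# Pretrained ResNet50 (ImageNet weights are downloaded on first use)
model = models.resnet50(pretrained=True)
# Append a softmax so the exported module returns class probabilities
model = nn.Sequential(
    model,
    nn.Softmax(1),
)
model.eval()  # inference mode: BatchNorm uses running statistics

# Trace with a dummy input of the same shape the iOS app passes in
input_tensor = torch.rand(1, 3, 224, 224)
script_model = torch.jit.trace(model, input_tensor)
os.makedirs("models", exist_ok=True)
script_model.save("models/resnet50.pt")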

For reference, this is how I save my model before taking it into iOS. It worked perfectly for ResNet34 and ResNet18.
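One way to sanity-check the export on desktop before moving to iOS is to reload the saved file with torch.jit.load and compare it against the eager model on a fresh input. A minimal sketch; the helper name `check_traced_module` is hypothetical:

```python
import torch
import torch.nn as nn


def check_traced_module(model: nn.Module, path: str, shape=(1, 3, 224, 224)) -> float:
    """Trace `model`, save it to `path`, reload it, and return the maximum
    absolute difference between eager and reloaded outputs on a new input."""
    model.eval()
    example = torch.rand(*shape)
    traced = torch.jit.trace(model, example)
    traced.save(path)
    reloaded = torch.jit.load(path)
    probe = torch.rand(*shape)  # fresh input, not the one used for tracing
    with torch.no_grad():
        expected = model(probe)
        actual = reloaded(probe)
    return (expected - actual).abs().max().item()
```

For example, calling it on nn.Sequential(models.resnet50(pretrained=True), nn.Softmax(1)) with "models/resnet50.pt" should report a difference near zero. A passing desktop check does not rule out mobile-only runtime problems, such as an issue in the mobile build's thread pool.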

Hi Haris, this is a known issue, and we’ve been working on a fix. The problem is that pthreadpool runs into a deadlock when running ResNet50. There are a couple of workarounds you can try:

  • Set the number of threads to one in ThreadPool.cc
  • Use a different mutex in the int ThreadPool::getNumThreads() const function
  • Use std::unique_lock<std::mutex> guard(executionMutex_, std::defer_lock); instead

Then recompile PyTorch from source by following the instructions here: https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source.
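To illustrate why using a different mutex in getNumThreads() can help, here is a toy Python model of the locking pattern. It assumes (this is not confirmed in the thread) that code running inside the pool calls getNumThreads() while the same thread already holds executionMutex_. Python's threading.Lock is, like std::mutex, not reentrant. `ToyThreadPool` is hypothetical and is not PyTorch's implementation.

```python
import threading


class ToyThreadPool:
    """Hypothetical, simplified model of the locking pattern; not PyTorch's
    actual ThreadPool implementation."""

    def __init__(self, num_threads: int, separate_count_mutex: bool):
        self._num_threads = num_threads
        # threading.Lock is non-reentrant, like std::mutex
        self.execution_mutex = threading.Lock()
        # Workaround 2: guard the thread count with its own mutex
        self._count_mutex = (
            threading.Lock() if separate_count_mutex else self.execution_mutex
        )

    def get_num_threads(self, timeout: float = 0.5) -> int:
        # A real std::mutex would block forever; the timeout lets the demo report it
        if not self._count_mutex.acquire(timeout=timeout):
            raise RuntimeError("deadlock: the mutex is already held by this thread")
        try:
            return self._num_threads
        finally:
            self._count_mutex.release()

    def run(self, job):
        # The execution mutex stays locked for the whole job
        with self.execution_mutex:
            return job(self)


def nested_job(pool: ToyThreadPool) -> int:
    # A job that asks the pool for its size while the pool is running it
    return pool.get_num_threads()


if __name__ == "__main__":
    fixed = ToyThreadPool(4, separate_count_mutex=True)
    print(fixed.run(nested_job))  # 4
    shared = ToyThreadPool(4, separate_count_mutex=False)
    try:
        shared.run(nested_job)
    except RuntimeError as err:
        print(err)
```

With the shared mutex the nested call can never acquire the lock; with a separate mutex it returns immediately. The std::defer_lock variant presumably targets the same problem from the other side: it constructs the std::unique_lock without locking the mutex.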

Hi Tao!

Thanks so much for the response! For “3.”, just to clarify: are you suggesting std::unique_lock as the alternative mutex for the second solution? I’ll try this tomorrow and update here. Also, I don’t have to do both one and two, do I? I can do one or the other?

Sorry for the confusion; I just re-edited the comment. Actually, I have a fix under review here: https://github.com/pytorch/pytorch/pull/29885. If you’d like to try it out, you can apply that PR as a patch.


I cloned the PyTorch source repo, applied the same file changes as the PR, and then followed the instructions to build the iOS libraries.

I then replaced the install folder under Pods/Libtorch/Install in my iOS project with the newly compiled install folder. I don’t get any path errors, and the project builds correctly.
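The folder swap described above can be scripted so the CocoaPods copy is kept as a backup. A minimal sketch, assuming the build script leaves its output in build_ios/install with include/ and lib/ subfolders, and taking the destination as a parameter (the post uses Pods/Libtorch/Install). `replace_libtorch_install` is a hypothetical helper:

```python
import shutil
from pathlib import Path


def replace_libtorch_install(build_install: Path, pod_install: Path) -> Path:
    """Replace the pod's install folder with a locally built one, keeping the
    previous folder next to it with a .bak suffix. Returns the backup path."""
    # Assumption: a build install folder contains include/ and lib/
    if not (build_install / "include").is_dir() or not (build_install / "lib").is_dir():
        raise FileNotFoundError(f"{build_install} does not look like a build install folder")
    backup = pod_install.with_name(pod_install.name + ".bak")
    if backup.exists():
        shutil.rmtree(backup)  # drop a stale backup from an earlier swap
    if pod_install.exists():
        pod_install.rename(backup)
    shutil.copytree(build_install, pod_install)
    return backup
```

Keeping the backup makes it easy to go back to the CocoaPods 1.3.1 build if the locally built libraries misbehave.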

However, now when I call predictImage on any model (including ResNet34, which worked before), I get a “check failed” error.

Tracing into Dispatcher.h in the repo, I find it breaks here:

  const std::string dispatchKeyStr = toString(*dispatchKey);
  TORCH_CHECK(false, "Could not run '", dispatchTable.operatorName(), "' with arguments",
          " from the '", dispatchKeyStr, "' backend. '",
          dispatchTable.operatorName(), "' is only available for these backends: ",
          dispatchTable.listAllDispatchKeys(), ".");
}

Previously, my LibTorch was installed from CocoaPods, specifically version 1.3.1. Is the version I compiled from the repo by following these steps equivalent?

For the Python JIT trace I’m using a 1.4 nightly build of PyTorch, which was working perfectly fine before.

I’ve seen similar check-failed errors solved by updating PyTorch or switching to nightly builds. Could that apply here?

In this case, how can I avoid this error while keeping the patch? Alternatively, could I take just the file with the pthread change from my newly compiled install folder and swap it into LibTorch in my project?

Thanks in advance,
Haris

Hi Hussain, if you use BUILD_PYTORCH_MOBILE=1 IOS_ARCH=arm64 ./scripts/build_ios.sh to build your libraries, you shouldn’t see that error.

If you’re still seeing it, I believe the breakage came in this morning or late yesterday. Obviously, our mobile CI failed to do its job. I’m working on adding simulator tests now. Sorry for the frustration. I’ll post updates here once we’ve fixed it.

Yep, that’s exactly how I built it, and I still got the error.

Please let me know here when the mobile CI is fixed and which version I should be building; I would greatly appreciate it!

@HussainHaris

The master branch is back to normal. You can try recompiling from source. Let me know if you have any questions.

Hey, I pulled on Monday, recompiled from source, and it worked just fine! Thanks for all the help.

One quick question before closing the thread: what are the current limitations on the types of PyTorch models that can go to mobile via tracing? For example, Inception, Faster R-CNN (for object detection), or segmentation models with encoders and decoders.

Most models should be compatible with tracing or scripting. More details are at https://pytorch.org/docs/stable/jit.html. If any ScriptModule works on the server but not on mobile, we would consider that a bug.
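The practical difference between the two: tracing runs forward once on an example input and records only the operations that actually executed, so data-dependent branches are frozen, while torch.jit.script compiles the Python source and keeps the control flow. Models with data-dependent branches or loops (common in detection post-processing, for instance) are therefore usually better served by scripting. A minimal sketch with a hypothetical toy module, `SignGate`:

```python
import torch
import torch.nn as nn


class SignGate(nn.Module):
    """Toy module whose output depends on a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return x * -1


positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing runs forward once on `positive` and records only the ops it executed,
# so the `x * 2` branch is baked in (PyTorch emits a TracerWarning about this).
traced = torch.jit.trace(SignGate(), positive)

# Scripting compiles the Python source, so both branches survive.
scripted = torch.jit.script(SignGate())

print(traced(negative))    # still computes x * 2
print(scripted(negative))  # takes the x * -1 branch
```

Both results are ScriptModules that can be saved and loaded the same way; the difference only shows up when an input takes a path the example input did not.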