Model with tensor and number operations errors in iOS

I traced a model and saved it using PyTorch 1.3.1 with the following code:

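import torch

class TestModule(torch.nn.Module):

    def forward(self, W):
        # Multiply the input tensor by a Python number
        g = 2 * W
        return g

W = torch.rand(10)

# Trace an instance of the module using W as the example input
test_model = torch.jit.trace(TestModule(), [W])
test_model.save("test_model.pt")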

I then loaded the model on iOS (LibTorch 1.3.1) using the following code:

torch::jit::script::Module testModel = torch::jit::load(filePath.UTF8String);

torch::IValue w = torch::IValue(torch::rand({10}));

testModel.forward({w});

It gives the following error on forward:

libc++abi.dylib: terminating with uncaught exception of type c10::Error: false CHECK FAILED at /Users/distiller/project/c10/core/Backend.h (tensorTypeIdToBackend at /Users/distiller/project/c10/core/Backend.h:106)

Hi @mark_jimenez, can you add these two lines before running forward?

torch::autograd::AutoGradMode guard(false);
at::AutoNonVariableTypeMode non_var_type_mode(true);

The first one tells the engine to disable autograd; the second one is a workaround. We can get rid of the workaround in 1.4.0, which will be released soon.

Let me know if you have any questions.

Hey @xta0. It worked when I added those two lines. Thank you!

I discovered a few more issues with the current build (1.3.1). Should I post them in a new thread, or is 1.4 going to be released soon so I can check them in that version?

@mark_jimenez you can post here in this thread, and I’ll follow up. In 1.4.0, you still need torch::autograd::AutoGradMode guard(false);, but at::AutoNonVariableTypeMode non_var_type_mode(true); is no longer necessary.

Thank you @xta0. I’ll ask about one issue at a time so it’s not overwhelming.

I couldn’t include LibTorch in a unit test target. I filed an issue with CocoaPods, and they said it’s because libnnpack.a isn’t built as a universal library.

I’ve also filed the issue here where you can find more details: https://github.com/pytorch/pytorch/issues/32040

Thank you for the help.

Dear @xta0, I have a similar error.

I exported a pre-trained model from Python to C++ (model.pt), and it works perfectly in a C++ application.
I tried to use the same exported model on my iPad, and it gives the following error in the torch::jit::load() method:

libc++abi.dylib: terminating with uncaught exception of type c10::Error: false CHECK FAILED at /Users/distiller/project/c10/core/Backend.h (tensorTypeIdToBackend at /Users/distiller/project/c10/core/Backend.h:106)
(no backtrace available)

When I add the two lines suggested above (the AutoGradMode and AutoNonVariableTypeMode guards) before running torch::jit::load(), I get:

libc++abi.dylib: terminating with uncaught exception of type c10::Error: !v.defined() || v.is_variable() INTERNAL ASSERT FAILED at /Users/distiller/project/torch/csrc/jit/ir.h (t_ at /Users/distiller/project/torch/csrc/jit/ir.h:718)
(no backtrace available)

Do you have any idea how to solve this problem?

@mark_jimenez this is a known issue, because NNPACK doesn’t support the iOS simulator architecture, as is shown here - https://github.com/Maratyszcza/NNPACK . So for the simulator build, operators are not being run via NNPACK. However, @AshkanAliabadi in our team has been actively working on XNNPACK, which will replace NNPACK in the future. Sorry for the inconvenience.

@fabricionarcizo What version of libtorch were you using? My guess is that your desktop version of PyTorch didn’t match the version of your libtorch. This will affect how your torchscript model is generated. You can verify the version by typing the command below
torch.version.__version__
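
A minimal sketch of the version comparison suggested above, assuming the LibTorch version is read from the Podfile or podspec and passed in as a string. `parse_major_minor` and `versions_match` are hypothetical helpers, not part of PyTorch, and treating a major.minor match as sufficient is an assumption that this thread does not confirm.

```python
import re


def parse_major_minor(version):
    """Return (major, minor) from a string such as '1.3.1' or '1.4.0+cpu'."""
    match = re.match(r"^(\d+)\.(\d+)", version.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def versions_match(pytorch_version, libtorch_version):
    """Hypothetical check: True when major and minor components agree."""
    return parse_major_minor(pytorch_version) == parse_major_minor(libtorch_version)


# On the desktop, the first argument would typically be torch.__version__;
# literal strings are used here so the sketch runs without PyTorch installed.
print(versions_match("1.3.1", "1.3.1"))
print(versions_match("1.4.0+cpu", "1.3.1"))
```

A mismatch reported by a check like this would only be a hint to regenerate the model with the matching desktop version, not a diagnosis.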

@xta0 I have used the pytorch 1.3.1 and libtorch 1.3.1. I generated the model in Python, and used torch.jit.trace and torch.jit.save to save the model in a .pt file. I’m able to load the model using torch::jit::load method in the C++ code. However, the same method (torch::jit::load) doesn’t work in the Objective-C version (also 1.3.1).

@fabricionarcizo Is it OK to paste your Python code here (or somewhere I can see it) so that I can debug? From your description, I can’t tell what could be going wrong.

@xta0 Thank you for the update on NNPACK. It’s alright, I understand that porting NNPACK is a big project. I’ll be patient for any updates.

BTW, congratulations on the 1.4 release. It fixed some of the problems in 1.3.1 that I was going to ask about.

I wanted to ask whether on-device training and/or a Swift and Objective-C API will be supported in the next release, or is at least on the 1.5 branch. I see that the podspec on the PyTorch GitHub repo is at 1.5 (https://github.com/pytorch/pytorch/blob/master/ios/LibTorch.podspec). Is there any way we can work on that version?

@mark_jimenez Thanks. We have teams working on enabling on-device training, but I’m not sure if that can be released in 1.5.0. As for the API wrappers, we have a proposal internally - https://github.com/pytorch/pytorch/pull/25541. We’ve been proactively collecting feedback from the community, but haven’t decided when to release it, so feel free to submit ideas, proposals, etc. The 1.5.0 version in .podspec is just a placeholder; nothing in particular has been done on that branch.

@xta0 Thank you for the update!

We’re working on a workaround right now for training on mobile by updating the weights of the model ourselves in TorchScript. We can’t get it to work with LibTorch 1.4 on iOS (not sure about Android). I think it may have something to do with how autograd is implemented on mobile.

For example, here’s the Python code compiled in PyTorch 1.4:

class TestModule(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        L = z.sum()
        L.backward()
        return x.grad, y.grad

# torch.jit.script compiles the module directly and takes no example inputs
model = torch.jit.script(TestModule())

And on iOS, the model is loaded like this:


    torch::jit::script::Module testModel = torch::jit::load(testFilePath.UTF8String);

    auto result = testModel.forward({torch::rand({1}, torch::TensorOptions().requires_grad(true)), torch::rand({1}, torch::TensorOptions().requires_grad(true))});

I get the same error as the one you filed here: https://github.com/pytorch/pytorch/pull/30067

So reading that, I guess autograd isn’t implemented for the mobile builds? Is that still the case for libtorch 1.4?

@mark_jimenez you’re right. Autograd is not available on mobile so far.
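
Since autograd is unavailable in the mobile build, one option in line with the workaround described above is to derive the gradients by hand. For the toy module in this thread, L = sum(x + y), so dL/dx and dL/dy are both all ones. The sketch below uses plain Python lists instead of tensors so that it stays self-contained. `forward_loss`, `manual_grads` and `sgd_step` are hypothetical names, and the learning rate is illustrative only.

```python
def forward_loss(x, y):
    """Compute L = sum(x + y) elementwise over two equal-length lists."""
    return sum(xi + yi for xi, yi in zip(x, y))


def manual_grads(x, y):
    """Hand-derived gradients of L = sum(x + y).

    Each element contributes to L with coefficient 1, so both
    gradients are vectors of ones with the shape of the inputs.
    """
    return [1.0] * len(x), [1.0] * len(y)


def sgd_step(params, grads, lr=0.1):
    """Plain gradient-descent update; lr is an illustrative value."""
    return [p - lr * g for p, g in zip(params, grads)]


x = [1.0]
y = [1.0]
grad_x, grad_y = manual_grads(x, y)
x_updated = sgd_step(x, grad_x)
y_updated = sgd_step(y, grad_y)
print(forward_loss(x, y), grad_x, grad_y)
print(x_updated, y_updated)
```

For a real model the derivatives would have to be worked out per layer and written with tensor operations that the mobile build supports. Whether that is practical depends on the model.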

@fabricionarcizo @xta0 I am having the same issue. The program breaks at this line in the module.h file:

IValue forward(std::vector<IValue> inputs) {
  return get_method("forward")(std::move(inputs)); // here it breaks
}

PyTorch: 1.4.0
LibTorch: 1.4.0
iOS 13.5
Xcode 11.5

Did you manage to resolve this? I happened to notice that the “HelloWorld” PyTorch iOS tutorial project on the PyTorch website also breaks in release mode. The model runs okay in debug mode.

@Harsh_Thaker thanks for reporting. I was able to repro. It looks like the crash was in pthreadpool when invoking convolution via NNPACK:

* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x10dcbd3a4)
    frame #0: 0x00000001013a83b4 HelloWorld`compute_input_packing + 256
    frame #1: 0x000000010108b2cc HelloWorld`std::__1::function<void (int, unsigned long)>::operator()(int, unsigned long) const + 48
    frame #2: 0x000000010108b884 HelloWorld`caffe2::ThreadPool::run(std::__1::function<void (int, unsigned long)> const&, unsigned long)::FnTask::Run() + 44
    frame #3: 0x000000010108b530 HelloWorld`caffe2::WorkersPool::Execute(std::__1::vector<std::__1::shared_ptr<caffe2::Task>, std::__1::allocator<std::__1::shared_ptr<caffe2::Task> > > const&) + 304
    frame #4: 0x000000010108b11c HelloWorld`caffe2::ThreadPool::run(std::__1::function<void (int, unsigned long)> const&, unsigned long) + 740
    frame #5: 0x00000001010885e4 HelloWorld`pthreadpool_compute_1d + 84
    frame #6: 0x0000000101087f00 HelloWorld`pthreadpool_compute_2d_tiled + 148
    frame #7: 0x00000001013a68b0 HelloWorld`nnp_convolution_inference + 3168
    frame #8: 0x0000000100aca13c HelloWorld`at::native::_nnpack_spatial_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long long>, c10::ArrayRef<long long>)::$_0::operator()(unsigned long) const + 680
    frame #9: 0x0000000100ac9284 HelloWorld`at::native::_nnpack_spatial_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long long>, c10::ArrayRef<long long>) + 1700
    frame #10: 0x0000000100d075a4 HelloWorld`at::TypeDefault::_nnpack_spatial_convolution(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long long>, c10::ArrayRef<long long>) + 248
...

Can you confirm whether you were seeing the same stack trace? If so, I believe that issue has been resolved in

Both of them will be merged to our 1.6.0 release.

#21 0x00000001030f86d0 in torch::jit::script::Module::forward(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue> >) at ios-demo-app-master/HelloWorld/HelloWorld/Pods/LibTorch/install/include/torch/csrc/jit/script/module.h:113

That’s the one in my case. I could solve the problem by changing the Optimization Level flag for release in the Swift compiler from the default, Optimize for Speed, to None. What could be causing that?