How to print out model weights in iOS LibTorch?

Seems like what people do in desktop C++ doesn’t quite work on iOS. For example, I tried:

for (const auto& params : module.parameters()) {
  std::cout << params.values() << std::endl;
}

which yields:

values not implemented for TensorTypeSet(VariableTensorId, CPUTensorId) (values at /Users/distiller/project/build_ios/aten/src/ATen/core/TensorMethods.h:3534) (no backtrace available)

Turns out params is a plain dense tensor (values() is only implemented for sparse tensors), so the following works:

for (const auto& params : _impl.parameters()) {
  params.print();
}

However, this only outputs the tensor shapes, not the actual values.

Can you try params[i][j].item<float>(), or use the raw buffer pointer: float* buffer = params.data_ptr<float>()?
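
Something like this should dump every value (a rough sketch, assuming the parameters are float tensors; the flatten step just makes the buffer safe to walk linearly):

for (const auto& param : module.parameters()) {
  // Flatten to a contiguous 1-D tensor so the raw buffer can be iterated directly.
  at::Tensor flat = param.contiguous().view({-1});
  const float* buffer = flat.data_ptr<float>();
  for (int64_t i = 0; i < flat.numel(); ++i) {
    std::cout << buffer[i] << " ";
  }
  std::cout << std::endl;
}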

Thanks @xta0, that works great. I am using your iOS sample project as a template, and have run into a big problem that I was wondering if you could help with (I can make another thread if needed).

The traced model outputs very different results in Swift than it does in Python, and I don’t know why. I followed your steps in the sample project:

First, trace the model and see the output on a tensor of ones:

model = ExplicitSpeechResModel(36, 19)
model.load('../output/model_1h.pt')
model.eval()
ones = torch.ones([1, 101, 40])
traced_script_module = torch.jit.trace(model, ones)
traced_script_module.save("../output/traced_model.pt")
traced_script_module(ones)

# outputs:
# tensor([[-18.7624, -30.3478, -31.0299,  -1.1888,   8.6857,  33.7217, -36.2783,
#          51.8277,  55.9391,   9.0642,   8.6428,  10.3509,   0.2688,  43.9576,
#          -7.1114, -55.3318, -16.7983, -13.5788,  -3.9336,  -1.1792,  14.3855,
#         -31.8519, 101.3712, -43.9597,  40.5726, -16.2946, -15.8538,  21.1088,
#         -31.5852, -14.2146, -14.5817,  19.9373, -21.5292,   9.4006, -45.0686,
#          21.4724]], grad_fn=<DifferentiableGraphBackward>)

On the Swift side, here is my TorchModule.mm. As you will see, the output on the same tensor is completely different:

@implementation TorchModule {
 @protected
  torch::jit::script::Module _impl;
}

- (nullable instancetype)initWithFileAtPath:(NSString*)filePath {
  self = [super init];
  if (self) {
    try {
      auto qengines = at::globalContext().supportedQEngines();
      if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
        at::globalContext().setQEngine(at::QEngine::QNNPACK);
      }
      _impl = torch::jit::load(filePath.UTF8String);
      _impl.eval();
    } catch (const std::exception& exception) {
      NSLog(@"%s", exception.what());
      return nil;
    }
  }
  return self;
}


- (NSArray<NSNumber*>*)predictImage:(void*)imageBuffer {
  try {

    // try out dummy input of all ones
    at::Tensor tensor = torch::ones({1, 101, 40});
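    // Disable autograd and use the non-variable dispatch path, as in the sample project.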
    torch::autograd::AutoGradMode guard(false);
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    auto outputTensor = _impl.forward({tensor}).toTensor();
    float* floatBuffer = outputTensor.data_ptr<float>();
    if (!floatBuffer) {
      return nil;
    }
    NSMutableArray* results = [[NSMutableArray alloc] init];
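    // Copy out the raw logits; 36 matches n_labels in ExplicitSpeechResModel(36, 19).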
    for (int i = 0; i < 36; i++) {
      [results addObject:@(floatBuffer[i])];
    }
    return [results copy];
  } catch (const std::exception& exception) {
    NSLog(@"%s", exception.what());
  }
  return nil;
}
@end

Calling predictImage results in the following vector:

[11.918683, -21.391111, -18.756794, -15.70252, -14.593732, 28.798603, -22.37965, 10.117706, 5.1135015, -8.376111, -25.258512, -7.270096, 0.9224758, 3.4262152, 28.566887, 2.90841, 25.247177, 35.124638, 14.7190695, -37.291008, -4.821145, 33.09956, 47.47553, 11.395653, 9.54897, -5.713372, -32.897644, -18.26301, -5.596691, -18.339537, -25.02614, -23.303043, -3.3603168, 31.69397, 3.0528922, 7.3663263]

It is totally different from the one in Python. The model weights seem to be the same, I’m calling model.eval() … not sure what else is missing.
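
One sanity check I can think of (just a sketch on my side, reusing the parameters() iteration from above) is to compare a simple checksum of all the weights against the same sum computed in Python:

double sum = 0;
for (const auto& param : _impl.parameters()) {
  // Accumulate every weight into one scalar for a quick cross-platform comparison.
  sum += param.sum().item<double>();
}
std::cout << "parameter checksum: " << sum << std::endl;

Here is my model architecture in case it is helpful: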

class ExplicitSpeechResModel(SerializableModule):
    def __init__(self, n_labels, n_maps):
        super().__init__()
        self.conv0 = nn.Conv2d(1, n_maps, (3, 3), padding=(1, 1), bias=False)
        self.conv1 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(0 // 3)), dilation=int(2**(0 // 3)), bias=False)
        self.bn1 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv2 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(1 // 3)), dilation=int(2**(1 // 3)), bias=False)
        self.bn2 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv3 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(2 // 3)), dilation=int(2**(2 // 3)), bias=False)
        self.bn3 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv4 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(3 // 3)), dilation=int(2**(3 // 3)), bias=False)
        self.bn4 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv5 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(4 // 3)), dilation=int(2**(4 // 3)), bias=False)
        self.bn5 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv6 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(5 // 3)), dilation=int(2**(5 // 3)), bias=False)
        self.bn6 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv7 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(6 // 3)), dilation=int(2**(6 // 3)), bias=False)
        self.bn7 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv8 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(7 // 3)), dilation=int(2**(7 // 3)), bias=False)
        self.bn8 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv9 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(8 // 3)), dilation=int(2**(8 // 3)), bias=False)
        self.bn9 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv10 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(9 // 3)), dilation=int(2**(9 // 3)), bias=False)
        self.bn10 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv11 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(10 // 3)), dilation=int(2**(10 // 3)), bias=False)
        self.bn11 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv12 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(11 // 3)), dilation=int(2**(11 // 3)), bias=False)
        self.bn12 = nn.BatchNorm2d(n_maps, affine=False)
        self.conv13 = nn.Conv2d(n_maps, n_maps, (3, 3), padding=int(2**(12 // 3)), dilation=int(2**(12 // 3)), bias=False)
        self.bn13 = nn.BatchNorm2d(n_maps, affine=False)
        self.output = nn.Linear(n_maps, n_labels)

    def forward(self, x):
        x = x.unsqueeze(1)
        y0 = F.relu(self.conv0(x))
        x = self.bn1(F.relu(self.conv1(y0)))
        y2 = F.relu(self.conv2(x)) + y0
        x = self.bn2(y2)
        x = self.bn3(F.relu(self.conv3(x)))
        y4 = F.relu(self.conv4(x)) + y2
        x = self.bn4(y4)
        x = self.bn5(F.relu(self.conv5(x)))
        y6 = F.relu(self.conv6(x)) + y4
        x = self.bn6(y6)
        x = self.bn7(F.relu(self.conv7(x)))
        y8 = F.relu(self.conv8(x)) + y6
        x = self.bn8(y8)
        x = self.bn9(F.relu(self.conv9(x)))
        y10 = F.relu(self.conv10(x)) + y8
        x = self.bn10(y10)
        x = self.bn11(F.relu(self.conv11(x)))
        y12 = F.relu(self.conv12(x)) + y10
        x = self.bn12(y12)
        x = self.bn13(F.relu(self.conv13(x)))
        x = x.view(x.size(0), x.size(1), -1) # shape: (batch, feats, o3)
        x = torch.mean(x, 2)
        return self.output(x)

Sorry for the long post. I really appreciate your help so far, and the sample project you made as well.

This seems to be related to https://github.com/pytorch/pytorch/pull/39591. Can you manually apply the PR and recompile from source? Please refer to the section here: https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source

Thanks. I’ll let you know how it goes!

I’m getting a bunch of errors like:

Undefined symbols for architecture x86_64:
  "c10::IValue::toModule() const", referenced from:
      torch::jit::slot_iterator_impl<torch::jit::detail::ModulePolicy>::next() in TorchModule.o
      torch::jit::slot_iterator_impl<torch::jit::detail::ParameterPolicy>::next() in TorchModule.o
  "c10::ivalue::Object::type() const", referenced from:
      torch::jit::slot_iterator_impl<torch::jit::detail::ModulePolicy>::valid() const in TorchModule.o
      torch::jit::slot_iterator_impl<torch::jit::detail::ModulePolicy>::next() in TorchModule.o
      torch::jit::slot_iterator_impl<torch::jit::detail::ParameterPolicy>::valid() const in TorchModule.o
      torch::jit::slot_iterator_impl<torch::jit::detail::ParameterPolicy>::next() in TorchModule.o

Seems like something went wrong with the initial build. I followed the posted instructions, with a few extra steps:

  1. Pod-install LibTorch 1.4.0, then replace the install folder in Pods/LibTorch with the install folder generated by the build.
  2. Change the C++ language dialect in Xcode to C++14.

Do you know what the problem might be?

If it is relevant, my CMake version is 3.12.3 and my clang version is 11.0.3. I’m using macOS Catalina.

You’re not supposed to use CocoaPods to recompile. Can you try the following?

  1. Clone pytorch and check out the master branch.
  2. Manually apply the changes in the PR (it’s a one-liner).
  3. Follow the section https://pytorch.org/mobile/ios/#build-pytorch-ios-libraries-from-source to recompile from source and set up Xcode.

I did as you suggested, and am still getting these errors. Seems like Xcode can’t find some header file.

Undefined symbols for architecture x86_64:
  "torch::jit::Object::find_method(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) const", referenced from:
      torch::jit::Object::get_method(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) const in TorchModule.o
  "at::GradMode::is_enabled()", referenced from:
      at::AutoGradMode::AutoGradMode(bool) in TorchModule.o
  "at::GradMode::set_enabled(bool)", referenced from:
      at::AutoGradMode::AutoGradMode(bool) in TorchModule.o
      at::AutoGradMode::~AutoGradMode() in TorchModule.o

I tried replacing the -force_load flag with -all_load and -ObjC, but it has the same result. My Header Search Paths value is $(PROJECT_DIR)/install/include. If I toggle this to “recursive”, the above errors disappear and are replaced by errors like:

No member named 'signbit' in the global namespace; did you mean 'sigwait'?

Seems like the recursive setting causes the wrong cmath to be picked up (presumably one of the bundled headers shadows the system one). I have been trying to fix this for a while, and I was wondering if you had any suggestions. Thanks for the help.

Did you add these two guards?

at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);
torch::autograd::AutoGradMode guard(false);

Yes, I had those guards. The error happens at build (link) time, not at runtime. Here is a minimal reproduction of the bug, if it helps.

@Bryan_Wang I’m kind of busy recently; I’ll ask my colleague to follow up with you.

Hi @Bryan_Wang, as @xta0 mentioned, the wrong-prediction issue should have been fixed by PR 39591. I’m now looking into your current build problem. Thanks for your patience.

Please follow the iOS tutorial (https://pytorch.org/mobile/ios/) very carefully, as @xta0 mentioned. Double-check that you’ve changed the values of Header Search Paths, Other Linker Flags, and bitcode accordingly before you build. When you drag the install folder into the project, please make sure all three check boxes in the import dialog are checked; this is important.

If you dragged it in successfully, the install folder icon in the sidebar should be yellow instead of blue.

I also attached a successfully built project based on the minimal project you attached (not sure whether PR 39591 is included; it’s just to show you the project file structure and settings). Hopefully it will help: [Dropbox link]. The password is pytorchios.

Thank you so much for your help. I have been traveling for the past two weeks, so I was unable to respond.

Would it be possible for you to reshare the link? I am getting “The owner hasn’t granted you access to this link” with no password prompt.

https://www.dropbox.com/s/mqe3pyme2mnh8o6/PytorchDev-Fixed.zip?dl=0 By the way, PyTorch 1.6 has now been released, so it should work as well if you import PyTorch via CocoaPods. Thanks.

Ah, I’ll do that as well. Thank you.