Loading BatchNorm1d with JIT

Hi!

I’m trying to load a PyTorch model in C++, using JIT. The model is defined as follows:

class JitModel(torch.jit.ScriptModule):
    def __init__(self):
        super(JitModel, self).__init__()

        self.n_layers = 5
        self.n_features = 14
        self.fc1 = torch.nn.Linear(14, 14)
        self.fc2 = torch.nn.Linear(14, 14)
        self.fc3 = torch.nn.Linear(14, 14)
        self.fc4 = torch.nn.Linear(14, 14)
        self.fc5 = torch.nn.Linear(14, 14)
        self.out = torch.nn.Linear(14, 1)
        self.normalise = torch.nn.BatchNorm1d(14)

    @torch.jit.script_method
    def forward(self, x):
        _x = x
        # _x = self.normalise(_x)
        _x = F.relu(self.fc1(_x))
        _x = F.relu(self.fc2(_x))
        _x = F.relu(self.fc3(_x))
        _x = F.relu(self.fc4(_x))
        _x = F.relu(self.fc5(_x))
        _x = torch.sigmoid(self.out(_x))
        return _x

where

F = torch.nn.functional

The model is trained in a Python 3 script and then saved using the JIT save() function.
When loaded with

torch.jit.load("model.pt") (in Python)

the model is loaded correctly.
When loaded with

torch::jit::load("model.pt") (in C++)

this happens:

terminate called after throwing an instance of 'torch::jit::script::ErrorReport'
  what():
Return value was annotated as having type Tuple[] but is actually of type Optional[Tuple[]]:
op_version_set = 0
def _check_input_dim(self,
    input: Tensor) -> Tuple[]:
  _0 = torch.ne(torch.dim(input), 2)
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~...  <--- HERE
  if _0:
    _1 = torch.ne(torch.dim(input), 3)
  else:
    _1 = _0
  if _1:
    ops.prim.RaiseException("Exception")
  else:
    pass
def forward(self,
Aborted (core dumped)

If I comment out the line:

        self.normalise = torch.nn.BatchNorm1d(14)

in the definition, the C++ script loads the model correctly.
I really have no idea what’s wrong with this implementation.

This looks like a bug on our side. I’ve filed a GitHub issue, feel free to follow along there.

Thanks for the answer, @Michael_Suo. I’ll follow it there.

Hi @Anthair,

Thanks for raising this issue. I tried to reproduce it, but it works fine on my side (even with self.normalise uncommented); the C++ frontend loads the model correctly. Can you verify whether this is still the case with our latest nightly? If it is, could you share your environment? It might be an environment-only issue.

Hi @wanchaol,

thank you for the help.
With the latest (20190222) nightly, the model loads in C++ (or at least, it doesn’t crash when calling the load function). However, it still crashes with the stable 1.0.0.
I’ll clean up the environment and try again with the stable release.

Thanks again for helping with the issue.

@Anthair 1.0.1 is our latest stable version and contains a bunch of bug fixes on top of 1.0.0. Please feel free to try 1.0.1 and see whether the error still shows up. If you want to try our latest features, you can stick to our nightlies or build master on your own :slight_smile:

@wanchaol, thanks for the link. I downloaded 1.0.1 and I can confirm that torch::jit::load no longer crashes. However, with either the 20190222 nightly or the stable 1.0.1, running inference with the model results in a crash:

Model Loaded
0x12321b0
terminate called after throwing an instance of 'torch::jit::JITException'
  what():
Exception:
operation failed in interpreter:
op_version_set = 0
def forward(self,
    x: Tensor) -> Tensor:
  _0 = torch.ne(torch.dim(x), 2)
  if _0:
    _1 = torch.ne(torch.dim(x), 3)
  else:
    _1 = _0
  if _1:
    ops.prim.RaiseException("Exception")
    ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  else:
    pass
  _2 = bool(self.normalise.training)
  if _2:
    _3 = True
  else:
    _3 = _2
  if _3:
    _4 = torch.add_(self.normalise.num_batches_tracked, 1, 1)
Aborted (core dumped)

The loading and inference code is:

#include <torch/script.h> // One-stop header.
#include <ATen/ATen.h>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  auto module = torch::jit::load(argv[1]);

  assert(module != nullptr);
  std::cout << "Model Loaded\n";

  std::vector<torch::jit::IValue> inputs;
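  // NOTE: a 1-D tensor of shape {14}; this turns out to be the cause of the
  // crash below (BatchNorm1d needs a batch dimension, see the replies further down).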
  inputs.push_back(torch::rand({14,}));

  std::cout << module << std::endl;

  auto output = module->forward(inputs);
}

I ran it on two different, totally independent setups, with the same results.
Here is the training file, for testing:
https://drive.google.com/file/d/1jEgTXMRFkUa1pynOK_rP0Lir39nzhtt8/view?usp=sharing

@Anthair thanks a lot for the follow-up, I will try to reproduce it and get back to you with more details.

@Anthair OK, I looked into your code. When you use BatchNorm1d, your input must be 2-D or 3-D (refer to the BatchNorm1d docs); changing the input to something like torch::rand({2, 14}) works well.
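
For reference, a minimal sketch of that change applied to the inference code above (the batch size of 2 is arbitrary; any 2-D or 3-D input with 14 features satisfies BatchNorm1d's dimension check):

std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::rand({2, 14}));  // (batch, features) instead of a 1-D tensor

// forward() now runs without tripping BatchNorm1d's dimension check
auto output = module->forward(inputs).toTensor();
std::cout << output << std::endl;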

@wanchaol, OK, it was a stupid error on my side. Thanks for the help; the inference works fine now, so I’d say we can close.

PS: I really would like to check the weights and the general state of the model. Is there something similar to load_state_dict() in C++? That would make everything easier.

Yeah, for the load_state_dict thing, I think it should be on our roadmap for the C++ frontend. Feel free to create an issue so we can track it :slight_smile:
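
In the meantime, as a rough sketch (assuming a libtorch build where torch::jit::load returns a module that exposes named_parameters(), which may not hold on the 1.0.x releases discussed here), you can at least dump the parameter names and shapes to inspect the model's state:

#include <torch/script.h>
#include <iostream>

int main() {
  // Assumes a recent libtorch where load() returns a torch::jit::Module by value;
  // on 1.0.x the API returns a shared_ptr instead.
  torch::jit::Module module = torch::jit::load("model.pt");

  for (const auto& p : module.named_parameters(/*recurse=*/true)) {
    // p.name is the dotted parameter path (e.g. "fc1.weight"), p.value is the tensor.
    std::cout << p.name << ": " << p.value.sizes() << std::endl;
  }
}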
