Serialization in C++ Frontend API

Hello,
I have a question about using the serialization API in C++. I want to save a model's weights in a Python program and load them in a C++ program.

In Python I'm doing the following:

class MaskRCNN(nn.Module):
    ...
    def save_weights(self, epoch):
        torch.save(self.state_dict(), self.checkpoint_path.format(epoch))
    ...

In C++ I'm trying to load the saved weights like this:

torch::serialize::InputArchive archive;
archive.load_from(params_path);

But this code fails with the following exception:

tag == RecordTags::FOOTER ASSERT FAILED at /development/lib/sources/pytorch/caffe2/serialize/inline_container.h:234, please report a bug to PyTorch. File footer has wrong record type. Is this file corrupted? (readAndValidateFileFooter at /development/lib/sources/pytorch/caffe2/serialize/inline_container.h:234)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6c (0x7f6a66008e3c in /home/kirill/development/lib/lib/libc10.so)
frame #1: <unknown function> + 0x5ccc90 (0x7f6a7ce8dc90 in /home/kirill/development/lib/lib/libtorch.so.1)
frame #2: <unknown function> + 0x5cd6b4 (0x7f6a7ce8e6b4 in /home/kirill/development/lib/lib/libtorch.so.1)
frame #3: torch::jit::load(std::istream&) + 0x2af (0x7f6a7ce8c6ff in /home/kirill/development/lib/lib/libtorch.so.1)
frame #4: torch::jit::load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x51 (0x7f6a7ce8c961 in /home/kirill/development/lib/lib/libtorch.so.1)
frame #5: torch::serialize::InputArchive::load_from(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x21 (0x7f6a7d057c91 in /home/kirill/development/lib/lib/libtorch.so.1)
frame #6: main + 0x5ce (0x55bd69bd3538 in /home/kirill/development/mlcpp/mask_rcnn_pytorch/build/mask-rcnn)
frame #7: __libc_start_main + 0xf3 (0x7f6a64c7d223 in /usr/lib/libc.so.6)
frame #8: _start + 0x2e (0x55bd69bd2bae in /home/kirill/development/mlcpp/mask_rcnn_pytorch/build/mask-rcnn)

In Python I can successfully load and use the saved weights like this:

model.load_state_dict(torch.load(weights_file_name))

Could you tell me whether I'm misusing the API or whether the two formats are simply incompatible? Which functions should I use to solve my problem, if any exist yet?

Thanks in advance.

After some investigation I came up with the following solution: build a dictionary that maps each entry of the model's state_dict to its shape and its values as nested lists, save it as JSON, then parse that file in C++ and initialize torch::Tensor objects from it. The torch.load/torch.save approach doesn't work in C++ because those functions are based on Python pickle; the stack trace above also shows that InputArchive::load_from goes through torch::jit::load, so it expects a TorchScript archive rather than a pickled state_dict. There are two other ways to share model data between languages, torch.jit and ONNX, but they rely on tracing the model's execution, which is not easy for complex models such as Mask R-CNN.

Implementation:

1. In Python:

import json
import sys

import torch

raw_state_dict = {}
for k, v in model.state_dict().items():
    if isinstance(v, torch.Tensor):
        # Store (shape, values); .cpu() makes .numpy() work for CUDA tensors too.
        raw_state_dict[k] = (list(v.size()), v.cpu().numpy().tolist())
    else:
        print("State parameter type error: {}".format(k))
        sys.exit(-1)

with open('mask_rcnn_coco.json', 'w') as outfile:
    json.dump(raw_state_dict, outfile)
2. In C++:

#include <torch/torch.h>

#include <rapidjson/error/en.h>
#include <rapidjson/filereadstream.h>
#include <rapidjson/reader.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <numeric>
#include <stack>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace {

// Expected layout: {"param.name": [[size0, size1, ...], nested values], ...}
enum class ReadState {
  None,
  DictObject,
  ParamName,
  SizeTensorPair,
  TensorSize,
  SizeTensorPairDelim,
  TensorValue,
  List
};

// SAX-style handler: rapidjson calls these methods while streaming the file,
// so the whole document is never materialized as a DOM.
struct DictHandler
    : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DictHandler> {
  // rapidjson reports non-negative integers via Uint/Uint64 and negative ones
  // via Int/Int64; all tensor values are stored as float.
  bool Double(double d) { return AddValue(d); }
  bool Int(int i) { return AddValue(i); }
  bool Int64(int64_t i) { return AddValue(static_cast<double>(i)); }
  bool Uint64(uint64_t u) { return AddValue(static_cast<double>(u)); }

  bool Uint(unsigned u) {
    // Inside the shape list, unsigned integers are dimension sizes.
    if (current_state_.top() == ReadState::TensorSize) {
      size_.push_back(static_cast<int64_t>(u));
      return true;
    }
    return AddValue(u);
  }

  bool AddValue(double value) {
    if (current_state_.top() != ReadState::List &&
        current_state_.top() != ReadState::TensorValue) {
      throw std::logic_error("Number parsing error");
    }
    blob_.push_back(static_cast<float>(value));
    return true;
  }

  bool Key(const char* str, rapidjson::SizeType length, bool /*copy*/) {
    key_.assign(str, length);
    if (current_state_.top() != ReadState::DictObject) {
      throw std::logic_error("Key parsing error");
    }
    current_state_.push(ReadState::ParamName);
    return true;
  }

  bool StartObject() {
    if (current_state_.top() != ReadState::None) {
      throw std::logic_error("Start object parsing error");
    }
    current_state_.pop();
    current_state_.push(ReadState::DictObject);
    return true;
  }

  bool EndObject(rapidjson::SizeType /*memberCount*/) {
    if (current_state_.top() != ReadState::DictObject) {
      throw std::logic_error("End object parsing error");
    }
    return true;
  }

  void StartData() {
    current_state_.push(ReadState::TensorValue);
    // int64_t initial value so the product is not truncated to int.
    expected_numel_ = std::accumulate(size_.begin(), size_.end(), int64_t{1},
                                      std::multiplies<int64_t>());
    blob_.clear();
    // reserve (not resize): values are appended with push_back.
    blob_.reserve(static_cast<size_t>(expected_numel_));
  }

  bool StartArray() {
    const ReadState state = current_state_.top();
    if (state == ReadState::List || state == ReadState::TensorValue) {
      current_state_.push(ReadState::List);
    } else if (state == ReadState::ParamName) {
      current_state_.push(ReadState::SizeTensorPair);
    } else if (state == ReadState::SizeTensorPair) {
      current_state_.push(ReadState::TensorSize);
      size_.clear();
    } else if (state == ReadState::SizeTensorPairDelim) {
      current_state_.pop();
      StartData();
    } else {
      throw std::logic_error("Start array parsing error");
    }
    return true;
  }

  // Pops the ParamName state and stores the finished (name, tensor) pair.
  void AddParam() {
    assert(current_state_.top() == ReadState::ParamName);
    current_state_.pop();
    dict.emplace_back(key_, tensor_);
  }

  bool EndArray(rapidjson::SizeType elementCount) {
    const ReadState state = current_state_.top();
    if (state == ReadState::List) {
      current_state_.pop();
    } else if (state == ReadState::SizeTensorPair) {
      current_state_.pop();
      AddParam();
    } else if (state == ReadState::TensorSize) {
      current_state_.pop();
      // An empty shape means a 0-dim tensor whose value is a bare number.
      scalar_ = (elementCount == 0);
      if (scalar_) {
        StartData();
      } else {
        current_state_.push(ReadState::SizeTensorPairDelim);
      }
    } else if (state == ReadState::TensorValue) {
      current_state_.pop();
      if (static_cast<int64_t>(blob_.size()) != expected_numel_) {
        throw std::logic_error("Tensor value count does not match its size");
      }
      // from_blob does not copy; clone() so the tensor owns its data
      // before blob_ is reused for the next parameter.
      tensor_ = torch::from_blob(blob_.data(), size_,
                                 torch::TensorOptions().dtype(torch::kFloat))
                    .clone();
      if (scalar_) {
        // For a scalar, this ']' also closes the [size, value] pair.
        assert(current_state_.top() == ReadState::SizeTensorPair);
        current_state_.pop();
        AddParam();
      }
    } else {
      throw std::logic_error("End array parsing error");
    }
    return true;
  }

  std::string key_;
  std::vector<int64_t> size_;
  torch::Tensor tensor_;
  std::vector<float> blob_;
  int64_t expected_numel_{0};
  bool scalar_{false};

  std::stack<ReadState> current_state_{{ReadState::None}};

  std::vector<std::pair<std::string, torch::Tensor>> dict;
};

// Closes the FILE* even if the handler throws during parsing.
struct FileCloser {
  void operator()(std::FILE* f) const { std::fclose(f); }
};
}  // namespace

std::vector<std::pair<std::string, torch::Tensor>> LoadStateDict(
    const std::string& file_name) {
  std::unique_ptr<std::FILE, FileCloser> file(
      std::fopen(file_name.c_str(), "rb"));
  if (!file) {
    throw std::runtime_error("Cannot open " + file_name);
  }
  char read_buffer[65536];
  rapidjson::FileReadStream is(file.get(), read_buffer, sizeof(read_buffer));
  rapidjson::Reader reader;
  DictHandler handler;
  rapidjson::ParseResult res = reader.Parse(is, handler);
  if (!res) {
    throw std::runtime_error(rapidjson::GetParseError_En(res.Code()));
  }
  return handler.dict;
}
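After LoadStateDict returns, each tensor still has to be copied into the parameter or buffer of the C++ model that has the same name, for example by iterating over the module's named_parameters() and named_buffers() under a torch::NoGradGuard and calling copy_(). That assumes the C++ modules register their submodules, parameters and buffers under the same names as the Python model. Before copying, it can help to compare the names, much like Python's load_state_dict(strict=True) reports missing and unexpected keys. Below is a minimal, libtorch-free sketch of that check; CheckKeys and KeyReport are hypothetical helpers, and the Value template parameter stands in for torch::Tensor.

```cpp
#include <set>
#include <string>
#include <utility>
#include <vector>

// Result of comparing loaded names against the names a model expects.
struct KeyReport {
  std::vector<std::string> missing;     // expected by the model, absent from the file
  std::vector<std::string> unexpected;  // present in the file, unknown to the model
};

// Hypothetical helper: Value stands in for torch::Tensor so this compiles
// without libtorch. Both result lists come out sorted by name.
template <typename Value>
KeyReport CheckKeys(const std::vector<std::pair<std::string, Value>>& loaded,
                    const std::vector<std::string>& expected) {
  std::set<std::string> loaded_names;
  for (const auto& entry : loaded) loaded_names.insert(entry.first);
  const std::set<std::string> expected_names(expected.begin(), expected.end());

  KeyReport report;
  for (const auto& name : expected_names) {
    if (loaded_names.count(name) == 0) report.missing.push_back(name);
  }
  for (const auto& name : loaded_names) {
    if (expected_names.count(name) == 0) report.unexpected.push_back(name);
  }
  return report;
}
```

With libtorch, the expected names would be collected from the keys of named_parameters() and named_buffers(); copying only after both lists come back empty avoids silently leaving some weights at their initial values.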

Hi, I was inspired by your solution, but I found JSON deserialization to be slow for large tensors.
So I used similar logic but implemented a simple binary read/write mechanism.
In Python, a torch tensor can be converted to a NumPy array and the array to a byte string, which lets Python write the tensor's underlying data to a binary file.
In C++, std::ifstream::read in binary mode takes a char* to a buffer and a length, and reads exactly that many bytes into the buffer.
Since the tensor created in libtorch has the same shape and element type as the Python tensor, its .numel() * .element_size() yields the exact number of bytes to read from the file, and tensor.sizes() determines how that flat data is laid out (strided) across the dimensions.
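The size arithmetic described above can be sketched without libtorch, assuming a contiguous, row-major tensor (NumPy's tobytes() writes C order by default). The element count is the product of the sizes, where an empty shape means a 0-dim scalar holding one element. The byte count is that product times the element size, and each stride is the product of all later sizes. Numel, ByteCount and ContiguousStrides are illustrative names; in libtorch the corresponding values come from tensor.numel(), tensor.element_size() and tensor.strides() on a contiguous tensor.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Number of elements for a shape; an empty shape is a 0-dim scalar (1 element).
int64_t Numel(const std::vector<int64_t>& sizes) {
  return std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Bytes occupied by a contiguous tensor: numel() * element_size().
std::size_t ByteCount(const std::vector<int64_t>& sizes,
                      std::size_t element_size) {
  return static_cast<std::size_t>(Numel(sizes)) * element_size;
}

// Row-major (C-contiguous) strides in elements: the last dimension has
// stride 1, and each earlier stride is the product of all later sizes.
std::vector<int64_t> ContiguousStrides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (std::size_t i = sizes.size(); i-- > 0;) {
    strides[i] = running;
    running *= sizes[i];
  }
  return strides;
}
```

For a float tensor of shape {2, 3, 4}, this gives 24 elements, 96 bytes and strides {12, 4, 1}, so element [i][j][k] sits at flat offset 12*i + 4*j + k.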

If you want, I can post my loading and saving code.
I think both load and save times were faster than with JSON.
