Hello,
I have successfully traced a model in Python and wanted to load it now in my C++ app.
I did so successfully using torch::jit::load(filepath, device)
.
However, I need to load the .pt
model file now from a const char* modelScript
. Its size is given as const size_t modelScriptSize
.
For this purpose, I created a helper class that subclasses std::streambuf
and wraps the buffer. With it, I am creating a std::istream
and supply that to torch::jit::load(is, device)
.
utils::MemReader mr(modelScript, modelScriptSize);
std::istream is(&mr);
mod_ = torch::jit::load(is, device_);
I am receiving an error:
istream reader failed: checking archive.
Exception raised from validate at /tmp/pytorch/pytorch/caffe2/serialize/istream_adapter.cc:32 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f1d1b9447ac in /usr/local/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7f1d1b910866 in /usr/local/lib/libc10.so)
frame #2: caffe2::serialize::IStreamAdapter::validate(char const*) const + 0x17b (0x7f1d1e22dbeb in /usr/local/lib/libtorch_cpu.so)
frame #3: caffe2::serialize::IStreamAdapter::read(unsigned long, void*, unsigned long, char const*) const + 0x41 (0x7f1d1e22dd21 in /usr/local/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x3f9c02b (0x7f1d1f91c02b in /usr/local/lib/libtorch_cpu.so)
frame #5: torch::jit::load(std::shared_ptr<caffe2::serialize::ReadAdapterInterface>, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x6c (0x7f1d1f918a5c in /usr/local/lib/libtorch_cpu.so)
frame #6: torch::jit::load(std::istream&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0xc2 (0x7f1d1f91aaa2 in /usr/local/lib/libtorch_cpu.so)
frame #7: torch::jit::load(std::istream&, c10::optional<c10::Device>) + 0x6a (0x7f1d1f91ab8a in /usr/local/lib/libtorch_cpu.so)
I verified that my stream outputs the exact same content as the .pt
file, by writing my istream into a file byte by byte and then comparing them via md5sum
on linux. It yielded the same hash.
Anybody able to spot my error here?
Here is my code for the MemReader
helper class:
class MemReader : public std::streambuf {
public:
MemReader(const char* data, size_t size);
private:
int_type underflow();
int_type uflow();
int_type pbackfail(int_type ch);
std::streamsize showmanyc();
const char* const begin_;
const char* const end_;
const char* current_;
};
MemReader::MemReader(const char* data, size_t size) :
begin_(data),
end_(data + size),
current_(data)
{}
MemReader::int_type MemReader::underflow() {
if (current_ == end_) {
return traits_type::eof();
}
return traits_type::to_int_type(*current_);
}
MemReader::int_type MemReader::uflow() {
if (current_ == end_) {
return traits_type::eof();
}
return traits_type::to_int_type(*current_++);
}
MemReader::int_type MemReader::pbackfail(int_type ch) {
if (current_ == begin_ || (ch != traits_type::eof() && ch != current_[-1])) {
return traits_type::eof();
}
return traits_type::to_int_type(*--current_);
}
std::streamsize MemReader::showmanyc() {
return end_ - current_;
}