Importing weights from jit::Module to torch::nn::Module

Hi,
I have a trained neural network that I exported to TorchScript (jit) format; the training itself was done in PyTorch. I have re-implemented the same network in libtorch, and now I want to initialize the weights of the libtorch network with the weights of the pre-trained TorchScript model.
How can I do this?

I don’t think the load_state_dict method has been implemented in libtorch yet, as described in this issue, but you could try the workaround proposed there.

Thanks,
I tried the proposed solution; however, I get errors with both .pt and .pth files. Specifically, if I use a .pth file containing the weights, I get:

terminate called after throwing an instance of 'c10::Error'
  what():  isGenericDict()INTERNAL ASSERT FAILED at "../aten/src/ATen/core/ivalue_inl.h":1904, please report a bug to PyTorch. Expected GenericDict but got None
Exception raised from toGenericDict at ../aten/src/ATen/core/ivalue_inl.h:1904 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xa0 (0xfffff7f0be30 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf8 (0xfffff7eed98c in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x5c (0xfffff7f09ce4 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #3: c10::IValue::toGenericDict() const & + 0xe4 (0xfffff413a3ec in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::jit::Unpickler::readInstruction() + 0xa38 (0xfffff64f63d0 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::jit::Unpickler::run() + 0x80 (0xfffff64f7878 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::jit::Unpickler::parse_ivalue() + 0x34 (0xfffff64f7a14 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: torch::jit::readArchiveAndTensors(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<std::function<c10::StrongTypePtr (c10::QualifiedName const&)> >, c10::optional<std::function<c10::intrusive_ptr<c10::ivalue::Object, c10::detail::intrusive_target_default_null_type<c10::ivalue::Object> > (c10::StrongTypePtr, c10::IValue)> >, c10::optional<c10::Device>, caffe2::serialize::PyTorchStreamReader&, c10::Type::SingletonOrSharedTypePtr<c10::Type> (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&), std::shared_ptr<torch::jit::DeserializationStorageContext>) + 0x31c (0xfffff64bc7fc in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::jit::pickle_load(std::vector<char, std::allocator<char> > const&) + 0x230 (0xfffff64cd210 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #9: torch::pickle_load(std::vector<char, std::allocator<char> > const&) + 0x28 (0xfffff6892f10 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)

With a .pt file, I get:

terminate called after throwing an instance of 'c10::Error'
  what():  type_resolver_INTERNAL ASSERT FAILED at "../torch/csrc/jit/serialization/unpickler.cpp":683, please report a bug to PyTorch. 
Exception raised from readGlobal at ../torch/csrc/jit/serialization/unpickler.cpp:683 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xa0 (0xfffff7f0be30 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0xe8 (0xfffff7eedad4 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: torch::jit::Unpickler::readGlobal(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x76c (0xfffff64f138c in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::jit::Unpickler::readInstruction() + 0xe18 (0xfffff64f67b0 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::jit::Unpickler::run() + 0x80 (0xfffff64f7878 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::jit::Unpickler::parse_ivalue() + 0x34 (0xfffff64f7a14 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::jit::readArchiveAndTensors(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<std::function<c10::StrongTypePtr (c10::QualifiedName const&)> >, c10::optional<std::function<c10::intrusive_ptr<c10::ivalue::Object, c10::detail::intrusive_target_default_null_type<c10::ivalue::Object> > (c10::StrongTypePtr, c10::IValue)> >, c10::optional<c10::Device>, caffe2::serialize::PyTorchStreamReader&, c10::Type::SingletonOrSharedTypePtr<c10::Type> (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&), std::shared_ptr<torch::jit::DeserializationStorageContext>) + 0x31c (0xfffff64bc7fc in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: torch::jit::pickle_load(std::vector<char, std::allocator<char> > const&) + 0x230 (0xfffff64cd210 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::pickle_load(std::vector<char, std::allocator<char> > const&) + 0x28 (0xfffff6892f10 in /home/nvidia/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)

What am I doing wrong?

I don’t know what might be causing the issue, but could you post this error (including a minimal code snippet to reproduce the issue) in the linked GitHub issue, please?