I am trying to add serialisation to my libtorch code. As my top-level class has quite a few member attributes, I have implemented two user-callable functions:
inline void save(const std::string& filename,
                 const std::string& key="") const
{
  torch::serialize::OutputArchive archive;
  write(archive, key).save_to(filename);
}

inline void load(const std::string& filename,
                 const std::string& key="")
{
  torch::serialize::InputArchive archive;
  archive.load_from(filename);
  read(archive, key);
}
Furthermore, I have implemented write and read functions of the form:
inline torch::serialize::OutputArchive& write(torch::serialize::OutputArchive& archive,
                                              const std::string& key="") const
{
  // Plain attributes are written directly under the prefixed key
  archive.write(key+"dim", torch::full({1}, dim_));
  geo_.write(archive, key+"geo");
  rhs_.write(archive, key+"rhs");
  sol_.write(archive, key+"sol");

  // The network goes into its own nested archive
  torch::serialize::OutputArchive archive_net;
  net_->save(archive_net);
  archive.write(key+"net", archive_net);

  // The optimizer goes into its own nested archive
  torch::serialize::OutputArchive archive_opt;
  opt_.save(archive_opt);
  archive.write(key+"opt", archive_opt);

  return archive;
}

inline torch::serialize::InputArchive& read(torch::serialize::InputArchive& archive,
                                            const std::string& key="")
{
  // Check that the stored dimension matches this object
  torch::Tensor tensor;
  archive.read(key+"dim", tensor);
  if (tensor.item<int64_t>() != dim_)
    throw std::runtime_error("dim mismatch");

  geo_.read(archive, key+"geo");
  rhs_.read(archive, key+"rhs");
  sol_.read(archive, key+"sol");

  // Read the network back from its nested archive
  torch::serialize::InputArchive archive_net;
  archive.read(key+"net", archive_net);
  net_->load(archive_net);

  // Read the optimizer back from its nested archive
  torch::serialize::InputArchive archive_opt;
  archive.read(key+"opt", archive_opt);
  opt_.load(archive_opt);

  return archive;
}
Here, net_ is of type MyGenerator, following the generator approach described here.
I have the following problems:
- The code runs up to the point where opt_ is loaded, which raises the following exception:
  libc++abi: terminating with uncaught exception of type c10::Error: loaded state dict contains a parameter group that has a different size than the optimizer's parameter group
  Exception raised from serialize at /tmp/libtorch-20220312-61658-dj98dd/torch/csrc/api/include/torch/optim/serialize.h:170 (most recent call first):
  frame #0: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 92 (0x10040d4c0 in libc10.dylib)
  frame #1: void torch::optim::serialize<torch::optim::AdamParamState, torch::optim::AdamOptions>(torch::serialize::InputArchive&, torch::optim::Optimizer&) + 832 (0x108880be8 in libtorch_cpu.dylib)
  frame #2: torch::optim::Adam::load(torch::serialize::InputArchive&) + 108 (0x10887f21c in libtorch_cpu.dylib)
- The net_ object is still empty after net_->load(archive_net).
Any help on any of the above issues is highly appreciated.
Kind regards,
Matthias