(libtorch) How to save model in MNIST cpp example?

I think the example was written before the stable release of libtorch. Nowadays you would implement the `torch::nn::Module` as follows:

// MNIST classifier following the layout of the libtorch MNIST example.
// The implementation class is named `NetImpl` so that TORCH_MODULE(Net) at the
// bottom can generate the `Net` module holder; torch::save / torch::load are
// meant to be used with module holders rather than with plain module objects.
struct NetImpl : torch::nn::Module {       // replaced Net by NetImpl
  NetImpl()                                // replaced Net by NetImpl
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10) {
    // Register each submodule under a stable name with the base Module.
    // NOTE(review): these names presumably become the keys used when the
    // model is serialized, so renaming them breaks loading of old checkpoints
    // — confirm against the libtorch serialization docs.
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  // Runs the network and returns log-probabilities over the 10 classes
  // (log_softmax along dim 1).
  // conv1 takes 1 input channel; the hard-coded 320 in view() equals
  // 20 channels * 4 * 4, which is what a 28x28 input yields after
  // conv(5) -> pool(2) -> conv(5) -> pool(2): 28->24->12->8->4.
  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    // Functional dropout follows the module's train/eval mode.
    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  // Channel-wise dropout on the conv2 feature maps. The older
  // torch::nn::FeatureDropout is deprecated in libtorch in favour of
  // Dropout2d. Default-constructed, so it uses the default dropout options.
  torch::nn::Dropout2d conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

TORCH_MODULE(Net); // creates module holder for NetImpl

`TORCH_MODULE(Net)` defines a module holder class `Net` that wraps a `std::shared_ptr<NetImpl>`. This enables you to call

torch::save(model, "model.pt");  // write the module holder `model` to the file "model.pt"

and

torch::load(model, "model.pt");  // read "model.pt" into `model`; `model` is passed as the destination, so it must already be constructed

You then need to replace calls of the form `model.function()` with `model->function()`. The holder also lets you call `model(input)` instead of `model->forward(input)`.

As the DCGAN tutorial for the C++ frontend puts it:

For example, the serialization API (`torch::save` and `torch::load`) only supports module holders (or plain `shared_ptr`).

5 Likes