I think the example was written prior to the stable release of libtorch. The way you would implement a torch::nn::Module now is as follows:
// Classic MNIST convolutional network, written against the stable libtorch
// API. The implementation class is named NetImpl so that TORCH_MODULE(Net)
// can generate the `Net` module holder (a std::shared_ptr<NetImpl> wrapper).
struct NetImpl : torch::nn::Module { // replaced Net by NetImpl
  NetImpl() // replaced Net by NetImpl
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10) {
    // Registering submodules exposes their parameters to optimizers and
    // to torch::save/torch::load under these archive keys.
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  // Forward pass: two conv + max-pool stages, flatten, two fully-connected
  // layers with dropout in between, log-softmax over the 10 class logits.
  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    // Flatten for the fully-connected layers; 320 = 20 channels * 4 * 4
    // spatial positions (assumes 28x28 MNIST input — confirm for other sizes).
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    // The functional dropout form must be told explicitly whether the
    // module is in training mode.
    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  // torch::nn::FeatureDropout was deprecated and later removed from the
  // stable API; Dropout2d is its replacement (zeroes entire feature maps,
  // default p = 0.5, default-constructible).
  torch::nn::Dropout2d conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};
TORCH_MODULE(Net); // defines `Net`, a module holder wrapping std::shared_ptr<NetImpl>
`TORCH_MODULE(Net)` creates a module holder, which is a `std::shared_ptr<NetImpl>`. This will enable you to call
torch::save(model, "model.pt");
and
torch::load(model, "model.pt");
You then need to replace function calls of the form `model.function()` with `model->function()`. This should also enable you to call `model(input)` instead of `model.forward(input)`.
As can be read in the DCGAN Tutorial: "For example, the serialization API (`torch::save` and `torch::load`) only supports module holders (or plain `shared_ptr`)."