Sample code to stop and resume training using the C++ API

Sometimes we face this situation: we have a large network, so we want to train it for a while, save it, and resume training another day. It took me a while to implement this in C++, so I am sharing the code here.

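#include <torch/torch.h>
#include <filesystem>  // std::filesystem::exists (C++17)

// Your customized net: register its layers in the constructor and define forward().
class MyNetImpl : public torch::nn::Module {
 public:
  // ... layers and forward() go here ...
};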

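// TORCH_MODULE generates the holder class MyNet, which wraps a std::shared_ptr<MyNetImpl>.
// torch::save and torch::load accept this holder directly, so the net can be restored later.
TORCH_MODULE(MyNet);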

// In the training loop:

if (epoch % 100 == 0) {
  torch::save(myNet, "myNet.pt");
  // If you use a stateful optimizer such as Adam, save it as well, e.g.
  // torch::save(optimizer, "optimizer.pt");
}

// The next time the program starts and you want to resume training:

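auto myNet = MyNet(std::make_shared<MyNetImpl>());
if (std::filesystem::exists("myNet.pt")) {
  torch::load(myNet, "myNet.pt");
  // Likewise restore the optimizer (constructed over myNet->parameters()) if you saved it:
  // torch::load(optimizer, "optimizer.pt");
}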

That is it.
