I’m not sure if this is a generic C++ question or a PyTorch question.
I’m building a C++ program using libtorch that defines several empty neural networks, and the user can choose via command-line options which model to train with the given dataset. The models are pre-defined in structures, but I’d like the training loop to be generic and work for any chosen model.
Models are subclasses of Module. But “forward()” isn’t defined in the Module base class, only in the models themselves. So to get a “generic” model pointer that lets me call “model->forward(x)” in the training loop regardless of the model chosen, I make each model inherit from a custom generic “MyModel” class, which inherits from PyTorch’s Module and declares a virtual forward() that the real models override. I then pass the “MyModel” pointer to the training loop, and this works fine.
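To make the setup concrete, here is a stripped-down sketch of that pattern. It does not use libtorch at all: Module is a stand-in for torch::nn::Module, a double stands in for a tensor, and ModelA, ModelB, make_model and run_step are made-up names for illustration, not my real classes.

```cpp
#include <memory>
#include <string>

// Stand-in for torch::nn::Module (hypothetical). Like the real base class,
// it has no forward().
struct Module {
    virtual ~Module() = default;
};

// The "middle-man" base class that adds a virtual forward().
struct MyModel : Module {
    virtual double forward(double x) = 0;
};

// Two placeholder models that override forward().
struct ModelA : MyModel {
    double forward(double x) override { return 2.0 * x; }
};

struct ModelB : MyModel {
    double forward(double x) override { return x + 1.0; }
};

// Picks a model by name, the way a command-line option would.
// Returns nullptr for an unknown name.
std::shared_ptr<MyModel> make_model(const std::string& name) {
    if (name == "a") return std::make_shared<ModelA>();
    if (name == "b") return std::make_shared<ModelB>();
    return nullptr;
}

// Stand-in for one step of the generic training loop: it only ever sees
// the MyModel pointer, never the concrete model type.
double run_step(const std::shared_ptr<MyModel>& model, double x) {
    return model->forward(x);
}
```

Calling run_step(make_model("a"), 3.0) dispatches to ModelA::forward through the base pointer, which is all the training loop needs for the forward pass.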
The problem comes in when I want to call “model->clone()” in my training loop to keep an up-to-date copy of the parameters of the best-performing model. clone() only works if the model inherits from the templated Cloneable class, and Cloneable itself inherits from Module, which cuts out the middle-man class I use to get a model-agnostic pointer with a forward() function. I can’t inherit from both, because they both inherit from Module, giving two paths to that base class.
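The clash can be reproduced without libtorch. The sketch below uses stand-ins again and reflects my understanding of how the real classes are laid out, which is an assumption on my part: Module has a virtual clone() that fails unless overridden, and Cloneable is a CRTP template that overrides it. ForwardOnly, CloneOnly and can_clone are hypothetical names.

```cpp
#include <memory>
#include <stdexcept>

// Stand-in for torch::nn::Module (assumed layout): clone() exists on the
// base class but fails unless a subclass overrides it.
struct Module {
    virtual ~Module() = default;
    virtual std::shared_ptr<Module> clone() const {
        throw std::logic_error("clone() not implemented");
    }
};

// Stand-in for torch::nn::Cloneable<Derived>: a CRTP class that overrides
// clone() by copy-constructing the concrete Derived type.
template <typename Derived>
struct Cloneable : Module {
    std::shared_ptr<Module> clone() const override {
        return std::make_shared<Derived>(static_cast<const Derived&>(*this));
    }
};

// My middle-man class: provides forward() but knows nothing about Cloneable.
struct MyModel : Module {
    virtual double forward(double x) = 0;
};

// Gets forward() through MyModel, but clone() falls back to Module's version.
struct ForwardOnly : MyModel {
    double forward(double x) override { return 2.0 * x; }
};

// Gets clone() through Cloneable, but a MyModel pointer cannot refer to it.
struct CloneOnly : Cloneable<CloneOnly> {
    double forward(double x) { return 2.0 * x; }
};

// What I would like to write, but cannot:
//   struct Both : MyModel, Cloneable<Both> { ... };
// Both would contain two separate Module subobjects, so converting a Both*
// to Module* is ambiguous.

// Reports whether clone() succeeds when called through a generic reference.
bool can_clone(const Module& m) {
    try {
        return m.clone() != nullptr;
    } catch (const std::logic_error&) {
        return false;
    }
}
```

Here can_clone returns false for a ForwardOnly and true for a CloneOnly, so each base class gives me only half of what the training loop needs.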
So how can I create a model-agnostic “generic” pointer to any of my declared models that supports both model->forward() and model->clone()? I feel like this should be obvious to a seasoned C++ programmer, but as a C/Python/Rust programmer I’m a bit stumped.
I’m also unsure what needs to go in reset(), if anything. But that’s a separate question.
Actual code without clone() support is here (with the real class names) if you’re interested: