C++ model pointer that supports both clone() and forward()?

I’m not sure if this is a generic C++ question or a PyTorch question.

I’m building a C++ program using libtorch that defines several empty (untrained) neural networks, and the user can choose via command-line options which model to train on the given dataset. The models are predefined as structs, but I’d like the training loop to be generic and work with whichever model is chosen.

Models are subclasses of Module. But “forward()” isn’t defined in the Module base class, only in the models themselves. So to get a “generic” model pointer that lets me call “model->forward(x)” in the training loop regardless of which model was chosen, I make each model inherit from a custom generic “MyModel” class, which inherits from PyTorch’s Module and declares a virtual forward() that the real models override. I then pass the “MyModel” pointer to the training loop, and this works fine.

The problem comes in when I want to call “model->clone()” in my training loop to keep an up-to-date copy of the parameters of the best-performing model. clone() only works if the model inherits from the templated Cloneable class, and Cloneable inherits from Module directly, which bypasses the middle-man class I use to get a model-agnostic pointer with a forward() function. I can’t inherit from both, because they both inherit from Module, giving two paths to that base class.
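To make the “two paths to that base class” problem concrete, here is a libtorch-free sketch with hypothetical stand-ins: Base plays torch::nn::Module, MiddleMan plays MyModel and CloneableLike plays Cloneable. The final class ends up containing two separate Base subobjects, so any unqualified use of a Base member is ambiguous. Virtual inheritance only merges the bases if every path uses it, and as far as I can tell Cloneable derives from Module non-virtually, so that is not an easy way out.

```cpp
// Stand-in for torch::nn::Module (illustration only).
struct Base {
  int id = 0;
};

// Stand-ins for MyModel and Cloneable<...>: both derive from Base
// without "virtual", so each one carries its own Base subobject.
struct MiddleMan : Base {};
struct CloneableLike : Base {};

struct Model : MiddleMan, CloneableLike {};

// Returns true when the two inherited Base subobjects are distinct objects.
bool has_two_bases(Model& m) {
  Base& via_middle = static_cast<MiddleMan&>(m);
  Base& via_cloneable = static_cast<CloneableLike&>(m);
  // Writing m.id or static_cast<Base&>(m) here would not compile: ambiguous path.
  return &via_middle != &via_cloneable;
}
```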

So how can I create a model-agnostic “generic” pointer to any of my declared models that supports both model->forward() and model->clone()? I feel like this should be obvious to a seasoned C++ programmer, but as a C/Python/Rust programmer I’m a bit stumped.

I’m also unsure what needs to go in reset(), if anything. But that’s a separate question.

Any help?

Actual code without clone() support is here (with the real class names) if you’re interested:

I don’t know if you can achieve that with Cloneable; in my code I just handle cloning myself using save/load, e.g.:

```cpp
#include <torch/torch.h>

#include <iostream>
#include <memory>
#include <sstream>
#include <string>

class IArchitecture : public torch::nn::Module
{
public:

	virtual torch::Tensor forward(torch::Tensor x) = 0;
	
	virtual void reset_parameters() = 0;

	virtual std::shared_ptr<IArchitecture> custom_clone() = 0;
};

typedef std::shared_ptr<IArchitecture> IArchitecturePtr;

// CRTP helper: C is the concrete model type and must be default-constructible.
template<class C>
struct IArchitectureClonable : public IArchitecture
{
	virtual std::shared_ptr<IArchitecture> custom_clone() override
	{
		// Fresh instance with the same architecture (and new random weights).
		auto retval = std::make_shared<C>();

		std::string data;

		// Serialize this module's parameters and buffers into an in-memory string...
		{
			std::ostringstream oss;
			torch::serialize::OutputArchive archive;

			this->save(archive);
			archive.save_to(oss);
			data = oss.str();
		}

		// ...and load them back into the new instance.
		{
			std::istringstream iss(data);
			torch::serialize::InputArchive archive;
			archive.load_from(iss);
			retval->load(archive);
		}
		
		return retval;
	}
};

class SimpleMLPImpl : public IArchitectureClonable<SimpleMLPImpl>
{
public:
	torch::nn::Linear fc1{ nullptr };
	torch::nn::Linear fc2{ nullptr };

	SimpleMLPImpl()
	{
		fc1 = register_module("fc1", torch::nn::Linear(32, 64));
		fc2 = register_module("fc2", torch::nn::Linear(64, 1));
	}

	virtual torch::Tensor forward(torch::Tensor x) override
	{
		x = torch::leaky_relu(fc1->forward(x));
		return fc2->forward(x);
	}
	
	void reset_parameters() override
	{
		fc1->reset_parameters();
		fc2->reset_parameters();
	}
};

int main()
{
	try
	{
		IArchitecturePtr model = std::make_shared<SimpleMLPImpl>();
		IArchitecturePtr cloned = model->custom_clone();
				
		auto testInput = torch::randn({ 1, 32 });

		// same output of both clone and original
		std::cout << model->forward(testInput) << std::endl;
		std::cout << cloned->forward(testInput) << std::endl;

		// reset parameters of original model
		model->reset_parameters();

		// different output
		std::cout << model->forward(testInput) << std::endl;
		std::cout << cloned->forward(testInput) << std::endl;
	}
	catch (const c10::Error& e)
	{
		std::cout << e.msg() << std::endl;
	}
	catch (const std::exception& e)
	{
		std::cout << e.what() << std::endl;
	}

	return 0;
}
```

Ahh… stringstream. Clever, I approve :slight_smile:. That should work for me as well. Thanks!

I would have spent hours trying to make clone work and assuming I was just crap at C++… which I am, but at least in this case it appears it’s not blatantly obvious.

I tried doing the same thing with Cloneable. Something like this seems to work (you’re right, it is not that obvious :grinning:; you have to get used to the CRTP, the Curiously Recurring Template Pattern):

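```cpp
#include <torch/torch.h>

#include <memory>

// Plain interface: deliberately NOT derived from torch::nn::Module, so the only
// path to Module goes through Cloneable<C>.
class IArchitecture 
{
   public:
	virtual ~IArchitecture() = default;
	virtual torch::Tensor forward(torch::Tensor x) = 0;
	virtual void reset_parameters() = 0;
	virtual void reset() = 0;
	virtual std::shared_ptr<IArchitecture> clone() = 0;
};

typedef std::shared_ptr<IArchitecture> IArchitecturePtr;

template<class C>
struct IArchitectureClonable : public torch::nn::Cloneable<C>, public IArchitecture
{
	virtual std::shared_ptr<IArchitecture> clone() override
	{
		// Cloneable<C>::clone() returns std::shared_ptr<torch::nn::Module>;
		// cross-cast it to the IArchitecture interface.
		return std::dynamic_pointer_cast<IArchitecture>(torch::nn::Cloneable<C>::clone());
	}
};
```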
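The CRTP part can be hard to read at first. Stripped of libtorch, the idea is that a template base class knows the concrete type and can therefore implement clone() once for all subclasses by copy-constructing that type. IModel, ModelCloneable and Scale are hypothetical names; as I understand it, torch::nn::Cloneable<C>::clone() works along similar lines, but additionally calls reset() on the copy and deep-copies the parameter tensors.

```cpp
#include <memory>

// The interface the training loop would use (hypothetical).
struct IModel {
  virtual ~IModel() = default;
  virtual double forward(double x) const = 0;
  virtual std::shared_ptr<IModel> clone() const = 0;
};

// CRTP helper: Derived is the concrete class, so clone() can copy-construct it
// without every model writing its own clone().
template <class Derived>
struct ModelCloneable : IModel {
  std::shared_ptr<IModel> clone() const override {
    return std::make_shared<Derived>(static_cast<const Derived&>(*this));
  }
};

// A toy "model" with one weight; it inherits clone() from ModelCloneable<Scale>.
struct Scale : ModelCloneable<Scale> {
  double weight = 2.0;
  double forward(double x) const override { return weight * x; }
};
```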

You also have to implement the reset() method. Judging by the Linear implementation, it basically does what would usually go in the constructor (register_parameter, register_module, …):

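```cpp
// Inside SimpleMLPImpl, which now derives from IArchitectureClonable<SimpleMLPImpl>:
SimpleMLPImpl()
{
    this->reset();
}

// Called by the constructor and again by Cloneable::clone() on the copy,
// so all register_module / register_parameter calls belong here.
virtual void reset() override
{
    fc1 = register_module("fc1", torch::nn::Linear(32, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 1));
}
```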