How to hold an nn::Module in a class member

In C++, a std::shared_ptr to a base class can hold an object of a derived class, but with the pimpl-style instantiation that libtorch uses, I don't see how to do the same. For example:

In the constructor of a hypothetical encoder class, some (libtorch) module members can be of different types:


		if (freq)
			conv = register_module("conv", 
				   torch::nn::Conv2d(
				   torch::nn::Conv2dOptions(chin, chout, { kernel_size, 1 })
				  .stride({ stride, 1 })
				  .padding({ padding, 0 })));
		else
			conv = register_module("conv", 
				   torch::nn::Conv1d(
				   torch::nn::Conv1dOptions(chin, chout, kernel_size)
				  .stride(stride)
				  .padding(padding)));

or

		if (norm)
			norm1 = register_module("norm1",
					torch::nn::GroupNorm(
					torch::nn::GroupNormOptions(norm_groups, chout)));
		else
			norm1 = register_module("norm1",
					torch::nn::Identity());

My first attempt was to declare the class members like this:

torch::nn::Module conv;
torch::nn::Module norm1;

This does not work, probably because of the pimpl-like instantiation scheme (the module holder macro). Sorry, I'm a real beginner with the libtorch API, so I may be wrong.

Another idea is to use std::variant, but that adds complexity and breaks direct calls through operator().

Is there a simple way to hold a variant-like module in a class member?

Self-Reply:

The solution is pretty trivial: just use torch::nn::AnyModule as the class member.