What is the common base class for my module (holders) in libtorch?

Consider the following snippet of C++:


struct Simple1Impl : public torch::nn::Module
{
    /* Constructor and forward */
};
TORCH_MODULE( Simple1 );

struct Simple2Impl : public torch::nn::Module
{
    /* Constructor and forward */
};
TORCH_MODULE( Simple2 );

What is the common base class for which I can write things such as:

WhatIsThisType model = condition ? (SomeCast) Simple1() : (SomeCast) Simple2();

I would like to dynamically decide the model. I have tried std::shared_ptr<torch::nn::Module> and torch::nn::AnyModule, both of which cause compilation errors.

As a reminder, the macro definition is this:

/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a
/// wrapper over a `std::shared_ptr<ImplType>`.
/// `Impl` is a type alias for `ImplType` which provides a way to call static
/// method of `ImplType`.
#define TORCH_MODULE_IMPL(Name, ImplType)                              \
  class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
   public:                                                             \
    using torch::nn::ModuleHolder<ImplType>::ModuleHolder;             \
    using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType;                    \
  }

/// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `<Name>Impl`.
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)

That would make Simple1 a class derived from torch::nn::ModuleHolder<Simple1Impl>. If the latter class behaves like std::shared_ptr<Simple1Impl>, then one should conceivably be able to cast it to torch::nn::ModuleHolder<torch::nn::Module>, though that may depend on the conversion operators and constructors it provides.

For the std::shared_ptr class, the construction via derived pointers seems to work because the constructor is templated:
https://en.cppreference.com/w/cpp/memory/shared_ptr/shared_ptr
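To make the std::shared_ptr behavior concrete, here is a minimal standard-library sketch. Base, Derived1, Derived2 and choose are made-up stand-ins (think torch::nn::Module and two Impl classes); nothing here is libtorch code.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

// Hypothetical stand-ins for a common module base class and two implementations.
struct Base {
    virtual ~Base() = default;
    virtual int id() const { return 0; }
};
struct Derived1 : Base {
    int id() const override { return 1; }
};
struct Derived2 : Base {
    int id() const override { return 2; }
};

// The templated converting constructor of std::shared_ptr makes this implicit
// derived-to-base conversion possible.
static_assert(
    std::is_convertible<std::shared_ptr<Derived1>, std::shared_ptr<Base>>::value,
    "shared_ptr<Derived1> should convert to shared_ptr<Base>");

// Pick an implementation at run time. Each return statement converts its own
// derived pointer to the common base pointer type.
std::shared_ptr<Base> choose(bool condition) {
    if (condition) {
        return std::make_shared<Derived1>();
    }
    return std::make_shared<Derived2>();
}
```

Note that a bare conditional expression such as `condition ? std::make_shared<Derived1>() : std::make_shared<Derived2>()` does not compile, because neither operand type converts to the other; each branch has to be converted to std::shared_ptr<Base> first, which is what the two return statements do here.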

However, that behavior is not provided by the wrapper class ModuleHolder, which seems to only accept assignments or copy/move constructions with one specific class.
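The difference can be sketched with a simplified wrapper. Holder below is a hypothetical, stripped-down stand-in for a holder class that only accepts a pointer to its own contained type; it is not the actual ModuleHolder source, and Base, Derived and unwrap are likewise illustrative names.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for a module base class and one implementation.
struct Base {
    virtual ~Base() = default;
    virtual int id() const { return 0; }
};
struct Derived : Base {
    int id() const override { return 7; }
};

// Simplified holder (assumption: only constructible from shared_ptr<T>,
// with no templated converting constructor from Holder<U>).
template <typename T>
class Holder {
public:
    explicit Holder(std::shared_ptr<T> impl) : impl_(std::move(impl)) {}
    const std::shared_ptr<T>& ptr() const { return impl_; }
    T* operator->() const { return impl_.get(); }

private:
    std::shared_ptr<T> impl_;
};

// The wrapper does not inherit shared_ptr's derived-to-base conversion...
static_assert(!std::is_convertible<Holder<Derived>, Holder<Base>>::value,
              "Holder<Derived> should not convert to Holder<Base>");
// ...although the pointer it wraps still has it.
static_assert(
    std::is_convertible<std::shared_ptr<Derived>, std::shared_ptr<Base>>::value,
    "the wrapped shared_ptr should still convert to the base pointer type");

// Going through the wrapped pointer recovers a common base-pointer type.
std::shared_ptr<Base> unwrap(const Holder<Derived>& holder) {
    return holder.ptr();
}
```

As far as I know, the real ModuleHolder also exposes the wrapped pointer through a ptr() member, so the same unwrapping idea may apply there, but I have not verified that it resolves the compilation errors mentioned above.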

So that possibility seems out, for now. I don’t know whether that omission is a deliberate design decision.

I am sharing a makeshift, don't-know-any-better solution for anyone who happens to wonder about the same thing.

It seems that a one-item torch::nn::Sequential can serve as a universal container for modules. Hence you might write:

torch::nn::Sequential model = condition ? torch::nn::Sequential( Simple1() ) : torch::nn::Sequential( Simple2() );