Consider the following snippet of C++:
struct Simple1Impl : public torch::nn::Module
{
/* Constructor and forward */
};
TORCH_MODULE( Simple1 );
struct Simple2Impl : public torch::nn::Module
{
/* Constructor and forward */
};
TORCH_MODULE( Simple2 );
Which common base type would let me write something like this:
WhatIsThisType model = condition ? (SomeCast) Simple1() : (SomeCast) Simple2();
I would like to choose the model at runtime. I have tried std::shared_ptr<torch::nn::Module> and torch::nn::AnyModule, both of which cause compilation errors.
As a reminder, the macro definition is this:
/// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a
/// wrapper over a `std::shared_ptr<ImplType>`.
/// `Impl` is a type alias for `ImplType` which provides a way to call static
/// method of `ImplType`.
#define TORCH_MODULE_IMPL(Name, ImplType) \
class Name : public torch::nn::ModuleHolder<ImplType> { /* NOLINT */ \
public: \
using torch::nn::ModuleHolder<ImplType>::ModuleHolder; \
using Impl TORCH_UNUSED_EXCEPT_CUDA = ImplType; \
}
/// Like `TORCH_MODULE_IMPL`, but defaults the `ImplType` name to `<Name>Impl`.
#define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl)
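To check what this expansion implies without pulling in libtorch, here is a self-contained sketch that reproduces only its shape. MiniModule, MiniModuleHolder and MINI_TORCH_MODULE are hypothetical stand-ins, not libtorch names, and the real torch::nn::ModuleHolder has many more members. The static_asserts show that Simple1 and Simple2 derive from two different instantiations of the holder template, neither of which is the holder of the shared base.

```cpp
#include <memory>
#include <type_traits>

// Hypothetical stand-in for torch::nn::Module (not the real class).
struct MiniModule {
  virtual ~MiniModule() = default;
};

// Hypothetical stand-in for torch::nn::ModuleHolder<Contained>:
// a thin wrapper that owns a std::shared_ptr<Contained>.
template <typename Contained>
class MiniModuleHolder {
 public:
  MiniModuleHolder() : impl_(std::make_shared<Contained>()) {}
  const std::shared_ptr<Contained>& ptr() const { return impl_; }

 private:
  std::shared_ptr<Contained> impl_;
};

// Same shape as TORCH_MODULE_IMPL / TORCH_MODULE above, minus the
// TORCH_UNUSED_EXCEPT_CUDA attribute.
#define MINI_TORCH_MODULE(Name)                           \
  class Name : public MiniModuleHolder<Name##Impl> {      \
   public:                                                \
    using MiniModuleHolder<Name##Impl>::MiniModuleHolder; \
    using Impl = Name##Impl;                              \
  }

struct Simple1Impl : MiniModule {};
MINI_TORCH_MODULE(Simple1);
struct Simple2Impl : MiniModule {};
MINI_TORCH_MODULE(Simple2);

// Each generated class derives from its own holder instantiation...
static_assert(std::is_base_of<MiniModuleHolder<Simple1Impl>, Simple1>::value,
              "Simple1 wraps Simple1Impl");
static_assert(std::is_base_of<MiniModuleHolder<Simple2Impl>, Simple2>::value,
              "Simple2 wraps Simple2Impl");
// ...but not from the holder of the shared base, even though
// Simple1Impl and Simple2Impl both derive from MiniModule.
static_assert(!std::is_base_of<MiniModuleHolder<MiniModule>, Simple1>::value,
              "no holder-of-base in Simple1's hierarchy");
static_assert(!std::is_convertible<Simple1, MiniModuleHolder<MiniModule>>::value,
              "no implicit conversion to the holder of the base");
```

The underlying C++ rule is that instantiations of a class template are not related by inheritance just because their template arguments are, so ModuleHolder<Simple1Impl> does not derive from ModuleHolder<torch::nn::Module>; the inheritance relationship only exists between the Impl types themselves.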
That would make Simple1 a class derived from torch::nn::ModuleHolder<Simple1Impl>. If that class behaves like std::shared_ptr<Simple1Impl>, then one should conceivably be able to cast it to torch::nn::ModuleHolder<torch::nn::Module>, though that may depend on which conversion operators it defines.
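Following that line of reasoning, here is a minimal sketch of the cast using hypothetical stand-ins (MiniModule, MiniModuleHolder, make_model, run_forward and the forward bodies are all invented for illustration; Simple1 and Simple2 are written out by hand instead of through the macro). Rather than converting the holder object, it takes the std::shared_ptr the holder wraps (libtorch's ModuleHolder exposes this through ptr()) and relies on the standard std::shared_ptr<Derived> to std::shared_ptr<Base> conversion, so both branches of the conditional have the same type. Because torch::nn::Module does not declare a forward method, the sketch recovers the concrete type with std::dynamic_pointer_cast before calling forward.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for torch::nn::Module: note there is no virtual forward.
struct MiniModule {
  virtual ~MiniModule() = default;
};

// Hypothetical stand-in for torch::nn::ModuleHolder<Contained>.
template <typename Contained>
class MiniModuleHolder {
 public:
  MiniModuleHolder() : impl_(std::make_shared<Contained>()) {}
  // Mirrors ModuleHolder::ptr(): returns the wrapped shared_ptr.
  const std::shared_ptr<Contained>& ptr() const {
    assert(impl_ != nullptr);
    return impl_;
  }

 private:
  std::shared_ptr<Contained> impl_;
};

// Hypothetical modules with made-up forward bodies.
struct Simple1Impl : MiniModule { int forward(int x) { return x + 1; } };
struct Simple2Impl : MiniModule { int forward(int x) { return x * 2; } };
// Roughly what TORCH_MODULE(Simple1) / TORCH_MODULE(Simple2) produce here.
struct Simple1 : MiniModuleHolder<Simple1Impl> {};
struct Simple2 : MiniModuleHolder<Simple2Impl> {};

// Both branches have type std::shared_ptr<MiniModule>, so the conditional
// operator is well-formed. The copied shared_ptr keeps the Impl alive after
// the temporary holder is destroyed.
std::shared_ptr<MiniModule> make_model(bool use_first) {
  return use_first ? std::shared_ptr<MiniModule>(Simple1().ptr())
                   : std::shared_ptr<MiniModule>(Simple2().ptr());
}

// The base pointer has no forward(); recover the concrete type to call it.
int run_forward(const std::shared_ptr<MiniModule>& model, int x) {
  if (auto m1 = std::dynamic_pointer_cast<Simple1Impl>(model)) return m1->forward(x);
  if (auto m2 = std::dynamic_pointer_cast<Simple2Impl>(model)) return m2->forward(x);
  return x;  // unknown module type: hypothetical pass-through fallback
}
```

Against libtorch itself, the analogous expression would presumably be condition ? std::shared_ptr<torch::nn::Module>(Simple1().ptr()) : std::shared_ptr<torch::nn::Module>(Simple2().ptr()), but that variant has not been compiled against the real library here.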