In C++, a `std::shared_ptr` to a base class can hold an instance of a subclass, but with libtorch's pimpl-style way of instantiating modules I don't see how to do the same. For example:
In the constructor of a hypothetical encoder class, some (libtorch) module members can be of different types:
```cpp
if (freq)
    conv = register_module("conv",
        torch::nn::Conv2d(
            torch::nn::Conv2dOptions(chin, chout, { kernel_size, 1 })
                .stride({ stride, 1 })
                .padding({ padding, 0 })));
else
    conv = register_module("conv",
        torch::nn::Conv1d(
            torch::nn::Conv1dOptions(chin, chout, kernel_size)
                .stride(stride)
                .padding(padding)));
```
or
```cpp
if (norm)
    norm1 = register_module("norm1",
        torch::nn::GroupNorm(
            torch::nn::GroupNormOptions(norm_groups, chout)));
else
    norm1 = register_module("norm1",
        torch::nn::Identity());
```
My first attempt was to declare the class members like this:
```cpp
torch::nn::Module conv;
torch::nn::Module norm1;
```
This does not work, probably because of the pimpl-like instantiation method (the module-holder macro). Sorry, I'm a real beginner with the libtorch API, so I may be wrong.
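For comparison, here is a minimal plain-C++ sketch of the layout described above. It assumes hypothetical stand-in types (`Module`, `Conv1dImpl`, `Conv2dImpl`, `Holder`) that only mimic the module-holder idea; they are not the real `torch::nn` classes. It illustrates why a base-class *value* member cannot receive a holder, and that storing the holder's `shared_ptr` as a pointer to the base works but hides `forward`. (libtorch's `ModuleHolder` appears to expose a similar `ptr()` accessor; check the pimpl header of your version.)

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins that only mimic the libtorch layout;
// this is NOT the real torch::nn API.
struct Module {  // plays the role of torch::nn::Module
  virtual ~Module() = default;
  virtual std::string name() const { return "Module"; }
};

struct Conv1dImpl : Module {
  std::string name() const override { return "Conv1d"; }
  std::vector<float> forward(const std::vector<float>& x) { return x; }
};

struct Conv2dImpl : Module {
  std::string name() const override { return "Conv2d"; }
  std::vector<float> forward(const std::vector<float>& x) { return x; }
};

// Mimics the module-holder idea: a thin value type owning a shared_ptr<Impl>.
// Note that Holder<Impl> does NOT derive from Module.
template <typename Impl>
class Holder {
 public:
  Holder() : impl_(std::make_shared<Impl>()) {}
  Impl* operator->() const { return impl_.get(); }
  std::shared_ptr<Impl> ptr() const { return impl_; }

 private:
  std::shared_ptr<Impl> impl_;
};

using Conv1d = Holder<Conv1dImpl>;
using Conv2d = Holder<Conv2dImpl>;

// `Module conv = Conv2d();` would not compile: the holder is not a Module.
// Going through the owned shared_ptr restores plain C++ polymorphism.
std::shared_ptr<Module> make_conv(bool freq) {
  if (freq) return Conv2d().ptr();
  return Conv1d().ptr();
}
```

With this layout, `make_conv(true)->name()` returns `"Conv2d"`, but `make_conv(true)->forward(x)` does not compile, because `forward` is declared only on the concrete implementations, not on the base.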
The other idea is to use `std::variant`, but that introduces complexity and breaks direct access to the forward pass through `operator()`.
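A minimal sketch of the `std::variant` route, in plain C++17 without libtorch. It assumes two hypothetical module types (`Conv1dLike`, `Conv2dLike`) that share one call signature; they are placeholders, not the real `torch::nn` classes. Wrapping the dispatch in a single `std::visit` inside the owning class keeps one `forward` call site for callers. Whether the same pattern fits libtorch's module holders is untested here.

```cpp
#include <cassert>
#include <variant>
#include <vector>

// Hypothetical stand-ins for two module types with the same call signature;
// NOT the real torch::nn classes.
struct Conv1dLike {
  std::vector<float> operator()(const std::vector<float>& x) const {
    return x;  // identity placeholder
  }
};

struct Conv2dLike {
  std::vector<float> operator()(const std::vector<float>& x) const {
    std::vector<float> y(x);
    for (float& v : y) v *= 2.0f;  // placeholder for "different" behaviour
    return y;
  }
};

// The variant member holds whichever module was chosen at construction.
using ConvVariant = std::variant<Conv1dLike, Conv2dLike>;

struct Encoder {
  ConvVariant conv;

  explicit Encoder(bool freq)
      : conv(freq ? ConvVariant{Conv2dLike{}} : ConvVariant{Conv1dLike{}}) {}

  // One std::visit restores a single call site for both alternatives;
  // every alternative must return the same type for this to compile.
  std::vector<float> forward(const std::vector<float>& x) const {
    return std::visit([&](const auto& m) { return m(x); }, conv);
  }
};
```

The cost is that every new alternative must be added to `ConvVariant`, which is the extra complexity mentioned above.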
Is there a simple way to hold a variant-like module in a class member?