Hi everyone,
I want to use torch::nn::init::calculate_gain() in C++. It is declared as follows:
TORCH_API double calculate_gain(NonlinearityType nonlinearity, double param = 0.01);
The parameter type NonlinearityType is defined as:
namespace torch {
namespace nn {
namespace init {
using NonlinearityType = c10::variant<
    enumtype::kLinear,
    enumtype::kConv1D,
    enumtype::kConv2D,
    enumtype::kConv3D,
    enumtype::kConvTranspose1D,
    enumtype::kConvTranspose2D,
    enumtype::kConvTranspose3D,
    enumtype::kSigmoid,
    enumtype::kTanh,
    enumtype::kReLU,
    enumtype::kLeakyReLU
>;
} // namespace init
} // namespace nn
} // namespace torch
I want to write the C++ equivalent of the Python call nn.init.calculate_gain('conv2d'). What should I pass to torch::nn::init::calculate_gain()?
Thanks