How to use torch::nn::init::calculate_gain() in C++

Hi everyone,
I want to use torch::nn::init::calculate_gain() in C++. It is declared as follows:
TORCH_API double calculate_gain(NonlinearityType nonlinearity, double param = 0.01);
The parameter type NonlinearityType is defined as:

namespace torch {
namespace nn {
namespace init {
using NonlinearityType = c10::variant<
  enumtype::kLinear,
  enumtype::kConv1D,
  enumtype::kConv2D,
  enumtype::kConv3D,
  enumtype::kConvTranspose1D,
  enumtype::kConvTranspose2D,
  enumtype::kConvTranspose3D,
  enumtype::kSigmoid,
  enumtype::kTanh,
  enumtype::kReLU,
  enumtype::kLeakyReLU
>;
} // namespace init
} // namespace nn
} // namespace torch

I want the C++ equivalent of Python's nn.init.calculate_gain('conv2d'). What should I pass to torch::nn::init::calculate_gain()?

Thanks


I would just try: torch::enumtype::kConv2D

By convention, the C++ options start with “k”, and the rest of the name tells you which Python string it corresponds to (for example, 'conv2d' becomes kConv2D).

We need to pass an object of that type, not the type itself, for example:

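torch::enumtype::kConv2D myType;  // default-constructed tag object
double gain = torch::nn::init::calculate_gain(myType);  // fully qualified: the argument's namespace alone won't find calculate_gain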

Or

double gain = torch::nn::init::calculate_gain(torch::enumtype::kConv2D());  // constructs a temporary object in the argument list

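This is because the options are not enum values. If you look at the definition, they are generated by the TORCH_ENUM_DECLARE preprocessor macro, and if you peek into TORCH_ENUM_DECLARE you will see that it declares empty structs. calculate_gain() therefore needs an object of the chosen struct type (which converts into the NonlinearityType variant), not the type name itself. The same macro also declares a constant of each type (e.g. torch::kConv2D), so torch::nn::init::calculate_gain(torch::kConv2D) should work as well.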