Determine AnyModule type, need to use get<> + catch?

I’m recreating in C++ the Python method from “Deep Learning with PyTorch” (p. 295), shown below.

I’m curious about the recommended C++ way to duplicate “type(m)”. As far as I can tell, given an AnyModule, the only way to determine the underlying type is to attempt a cast via get<>, which throws if the type doesn’t match. My C++ code is below.

Is there any way to determine the type without a throw/catch?

```python
# Excerpt from the model class; assumes `import math` and `from torch import nn`.
def _init_weights(self):
    for m in self.modules():
        if type(m) in {
            nn.Linear,
            nn.Conv3d,
        }:
            nn.init.kaiming_normal_(
                m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
            )
            if m.bias is not None:
                fan_in, fan_out = \
                    nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                bound = 1 / math.sqrt(fan_out)
                # normal_(tensor, mean, std): this samples with mean=-bound, std=bound
                nn.init.normal_(m.bias, -bound, bound)
```
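
For reference, `_calculate_fan_in_and_fan_out` (called in both the Python and C++ versions) treats dimension 0 of the weight as the output feature maps and dimension 1 as the input feature maps, and multiplies each by the product of any remaining (kernel) dimensions. The sketch below reproduces only that arithmetic in plain C++ on a shape vector, to show which number ends up in `bound = 1 / sqrt(fan_out)`. The helper names `fan_in_and_fan_out` and `bias_bound` are hypothetical, not libtorch API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper mirroring the arithmetic of _calculate_fan_in_and_fan_out:
// dim 0 = output feature maps, dim 1 = input feature maps,
// dims 2.. = kernel (receptive field) sizes.
std::pair<std::int64_t, std::int64_t> fan_in_and_fan_out(const std::vector<std::int64_t>& shape)
{
    assert(shape.size() >= 2 && "fan in/out needs a weight with at least 2 dimensions");
    std::int64_t receptive_field = 1;
    for (std::size_t i = 2; i < shape.size(); ++i)
        receptive_field *= shape[i];
    const std::int64_t fan_in = shape[1] * receptive_field;
    const std::int64_t fan_out = shape[0] * receptive_field;
    return {fan_in, fan_out};
}

// Hypothetical helper: bound = 1 / sqrt(fan_out), as in _init_weights.
double bias_bound(const std::vector<std::int64_t>& shape)
{
    return 1.0 / std::sqrt(static_cast<double>(fan_in_and_fan_out(shape).second));
}
```

For a Linear layer with 4 inputs and 16 outputs the weight shape is {16, 4}, so fan_out = 16 and the bound is 0.25. For a Conv3d with 2 input channels, 8 output channels and a 3x3x3 kernel the shape is {8, 2, 3, 3, 3}, giving fan_in = 54 and fan_out = 216.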
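
```cpp
// m_Model is a torch::nn::Sequential, so each *it is a torch::nn::AnyModule.
for (auto it = m_Model->begin(); it != m_Model->end(); ++it)
{
	// Can't find any way to check the type of an AnyModule except get<>, which may throw.
	// https://pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_any_module.html
	// type_info(): compiler reports a deleted function.
	// auto info = it->type_info();
	try
	{
		// throws: "Attempted to cast module of type class torch::nn::Conv1dImpl to type class torch::nn::LinearImpl"
		auto m = it->get<torch::nn::Linear>();
		// note: the Python version uses nonlinearity='relu' (torch::kReLU)
		torch::nn::init::kaiming_normal_(m->weight.data(), 0.0, torch::kFanOut, torch::kTanh);
		if (m->bias.defined())  // bias is an undefined tensor when the layer has no bias
		{
			auto Fan = torch::nn::init::_calculate_fan_in_and_fan_out(m->weight.data());
			auto FanOut = std::get<1>(Fan);
			auto Bound = 1.0 / std::sqrt(static_cast<double>(FanOut));
			torch::nn::init::normal_(m->bias, -Bound, Bound);
		}
		continue;
	}
	catch (...)
	{
		// Not a Linear: fall through to the next type check.
	}
	// ... similar try/catch blocks for the other module types ...
}
```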
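
On the commented-out type_info() line: AnyModule::type_info() returns a `const std::type_info&`, and std::type_info has a deleted copy constructor, so `auto info = it->type_info();` (which tries to copy it) fails with a deleted-function error. Binding a reference instead, e.g. `const auto& info = it->type_info();`, and comparing it with `typeid(torch::nn::LinearImpl)` should avoid the throw; note that the stored type is the Impl class, as the exception message above shows. Since libtorch may not be at hand, the sketch below uses a hypothetical stand-in, `AnyModuleLike`, that only mimics that part of the interface; it is not the real class.

```cpp
#include <memory>
#include <string>
#include <typeinfo>
#include <utility>

// Hypothetical stand-ins for torch::nn::Module and two Impl classes.
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module {};
struct Conv1dImpl : Module {};

// Hypothetical stand-in for AnyModule: remembers typeid of the concrete
// Impl it was built from and hands it out by const reference.
class AnyModuleLike
{
public:
    template <typename ModuleType>
    explicit AnyModuleLike(std::shared_ptr<ModuleType> module)
        : module_(std::move(module)), type_(typeid(ModuleType)) {}

    const std::type_info& type_info() const { return type_; }

private:
    std::shared_ptr<Module> module_;
    const std::type_info& type_;
};

// Classifies without exceptions: bind by const reference (copying a
// std::type_info does not compile) and compare against typeid of the Impl.
std::string kind_of(const AnyModuleLike& m)
{
    const std::type_info& info = m.type_info();  // `auto info = ...` would not compile
    if (info == typeid(LinearImpl)) return "linear";
    if (info == typeid(Conv1dImpl)) return "conv1d";
    return "other";
}
```

With the real class the check would read `if (it->type_info() == typeid(torch::nn::LinearImpl))`, after which get<torch::nn::Linear>() should not throw for that element. An exact typeid match ignores subclasses, whereas a std::dynamic_pointer_cast also accepts types derived from the Impl.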

You could try using std::dynamic_pointer_cast, for example:

```cpp
void InitWeights()
{
	for (auto& m : modules(false))
	{
		auto conv3dPtr = std::dynamic_pointer_cast<torch::nn::Conv3dImpl>(m);

		if (conv3dPtr != nullptr)
		{
			torch::nn::init::kaiming_normal_(conv3dPtr->weight, 0.0, torch::kFanOut, torch::kReLU);
		}

		auto linearPtr = std::dynamic_pointer_cast<torch::nn::LinearImpl>(m);

		if (linearPtr != nullptr)
		{
			torch::nn::init::kaiming_normal_(linearPtr->weight, 0.0, torch::kFanOut, torch::kReLU);
		}
	}
}
```
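
The same dispatch can be exercised without libtorch. The sketch below uses hypothetical stand-ins for Module, LinearImpl and Conv3dImpl (not the real classes, and `has_bias` is an invented flag standing in for a defined bias tensor) to show that std::dynamic_pointer_cast returns nullptr rather than throwing on a mismatch, so each check is just a branch. Declaring the pointer in the `if` condition limits its scope to the branch that uses it.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins: a polymorphic base and a few concrete Impl types.
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module { bool has_bias = true; };
struct Conv3dImpl : Module {};
struct ReLUImpl : Module {};

// Returns one label per module; a real InitWeights would call the
// torch::nn::init functions in each branch instead.
std::vector<std::string> init_weights(const std::vector<std::shared_ptr<Module>>& modules)
{
    std::vector<std::string> actions;
    for (const auto& m : modules)
    {
        // dynamic_pointer_cast yields nullptr on a type mismatch; no exception is involved.
        if (auto linear = std::dynamic_pointer_cast<LinearImpl>(m))
            actions.push_back(linear->has_bias ? "linear+bias" : "linear");
        else if (std::dynamic_pointer_cast<Conv3dImpl>(m))
            actions.push_back("conv3d");
        else
            actions.push_back("skip");
    }
    return actions;
}
```

For the Sequential loop in the question, AnyModule::ptr() returns the contained module as a std::shared_ptr<torch::nn::Module>, so `std::dynamic_pointer_cast<torch::nn::LinearImpl>(it->ptr())` should work the same way.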

Perfect. Thanks so much. My attempts at casting didn’t include the “Impl”.
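
The “Impl” detail is easy to miss: torch::nn::Linear is a ModuleHolder that wraps a std::shared_ptr<torch::nn::LinearImpl>, and LinearImpl is the class that actually derives from torch::nn::Module, so modules() hands back pointers to the Impl objects. The mock below uses a hypothetical `LinearHolder` in place of the real holder template to show that, in this setup, a dynamic_pointer_cast to the wrapper type compiles but never finds a match, while the cast to the Impl succeeds and the result can be wrapped in the holder again.

```cpp
#include <memory>
#include <utility>

// Hypothetical stand-ins for torch::nn::Module and torch::nn::LinearImpl.
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module {};

// Hypothetical stand-in for torch::nn::Linear (a ModuleHolder<LinearImpl>):
// it is not derived from Module, it only owns a shared_ptr to the Impl.
class LinearHolder
{
public:
    explicit LinearHolder(std::shared_ptr<LinearImpl> impl) : impl_(std::move(impl)) {}
    const std::shared_ptr<LinearImpl>& ptr() const { return impl_; }

private:
    std::shared_ptr<LinearImpl> impl_;
};

// The cast to the holder type compiles (Module is polymorphic) but cannot
// succeed, because no Module object is a LinearHolder.
bool holder_cast_succeeds(const std::shared_ptr<Module>& m)
{
    return std::dynamic_pointer_cast<LinearHolder>(m) != nullptr;
}

// The cast to the Impl succeeds; the pointer can then be rewrapped,
// much as torch::nn::Linear can be constructed from a shared_ptr<LinearImpl>.
bool impl_cast_succeeds(const std::shared_ptr<Module>& m)
{
    if (auto impl = std::dynamic_pointer_cast<LinearImpl>(m))
    {
        LinearHolder linear(impl);
        return linear.ptr() == impl;
    }
    return false;
}
```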