Determine AnyModule type, need to use get<> + catch?

I’m recreating in C++ the Python method from “Deep Learning with PyTorch”, p. 295, shown below.

I’m curious about the recommended C++ way to duplicate “type(m)”. As far as I can tell, given an AnyModule, the only way to determine the underlying type is to attempt a cast via get<>. My C++ code is below.

Is there any way to determine the type without suffering a throw/catch?

def _init_weights(self):
    for m in self.modules():
        if type(m) in {
            ...  # the set of layer types is missing here
        }:
            nn.init.kaiming_normal_(
                m.weight, a=0, mode='fan_out', nonlinearity='relu',
            )
            if m.bias is not None:
                fan_in, fan_out = \
                    nn.init._calculate_fan_in_and_fan_out(m.weight)
                bound = 1 / math.sqrt(fan_out)
                nn.init.normal_(m.bias, -bound, bound)

    for (auto it = m_Model->begin(); it != m_Model->end(); ++it)
    {
        // Can't find any way to check the type of an AnyModule except get<> with a potential throw.
        // type_info() gives a "deleted function" error.
        // auto info = it->type_info();
        try
        {
            // throws: "Attempted to cast module of type class torch::nn::Conv1dImpl to type class torch::nn::LinearImpl"
            auto m = it->get<torch::nn::Linear>();
            torch::nn::init::kaiming_normal_(m->weight, 0.0, torch::kFanOut, torch::kTanh);
            if (m->bias.numel())
            {
                auto Fan = torch::nn::init::_calculate_fan_in_and_fan_out(m->weight);
                auto FanOut = std::get<1>(Fan);
                auto Bound = 1.0 / sqrt(FanOut);
                torch::nn::init::normal_(m->bias, -Bound, Bound);
            }
        }
        catch (...)
        {
            // Not a Linear module: skip it.
        }
    }
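
A possible reading of the "deleted function" error mentioned in the comments above: `std::type_info` cannot be copied, so `auto info = ...` tries to call its deleted copy constructor, whereas binding the result to a `const auto&` compiles. The sketch below shows this with standard C++ only. `Module`, `LinearImpl`, `Conv1dImpl` and `is_exactly` are hypothetical stand-ins rather than libtorch types, and the sketch assumes `AnyModule::type_info()` returns a `const std::type_info&`.

```cpp
#include <typeinfo>

// Hypothetical stand-ins for the libtorch class hierarchy (not the real types).
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module {};
struct Conv1dImpl : Module {};

// Returns true if the object behind `m` has exactly the dynamic type T.
template <typename T>
bool is_exactly(const Module& m)
{
    // auto info = typeid(m);      // error: std::type_info's copy constructor is deleted
    const auto& info = typeid(m);  // fine: bind by reference instead of copying
    return info == typeid(T);
}
```

Under that assumption, the same comparison against `typeid(torch::nn::LinearImpl)` would identify a module without a try/catch, though it only matches the exact type and not classes derived from it.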

You could try using std::dynamic_pointer_cast, which returns a null pointer instead of throwing when the type doesn’t match. For example:

void InitWeights()
{
	// modules(false): iterate over the submodules without including this module itself
	for (auto& m : modules(false))
	{
		auto conv3dPtr = std::dynamic_pointer_cast<torch::nn::Conv3dImpl>(m);

		if (conv3dPtr != nullptr)
			torch::nn::init::kaiming_normal_(conv3dPtr->weight, 0.0, torch::kFanOut, torch::kReLU);

		auto linearPtr = std::dynamic_pointer_cast<torch::nn::LinearImpl>(m);

		if (linearPtr != nullptr)
			torch::nn::init::kaiming_normal_(linearPtr->weight, 0.0, torch::kFanOut, torch::kReLU);
	}
}
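
To see the casting pattern in isolation, here is a minimal sketch that compiles without libtorch. `Module`, `LinearImpl`, `Conv3dImpl`, `ReLUImpl` and `classify` are hypothetical stand-ins, not the real torch::nn classes. The point is only that a failed `std::dynamic_pointer_cast` yields an empty `shared_ptr`, so no exception handling is needed.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-ins for torch::nn::Module and a few *Impl classes.
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module {};
struct Conv3dImpl : Module {};
struct ReLUImpl : Module {};

// Classify a module without exceptions: a cast to the wrong type
// simply produces a null shared_ptr, which tests as false.
std::string classify(const std::shared_ptr<Module>& m)
{
    if (std::dynamic_pointer_cast<Conv3dImpl>(m)) return "conv3d";
    if (std::dynamic_pointer_cast<LinearImpl>(m)) return "linear";
    return "other";
}
```

Unlike Python's `type(m) in {...}`, which compares exact types, a dynamic cast also succeeds for classes derived from the target type.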

Perfect. Thanks so much. My attempts at casting didn’t include the “Impl”.
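
Why the “Impl” suffix matters can be sketched as follows. In libtorch, `torch::nn::Linear` is a holder that wraps a `std::shared_ptr<torch::nn::LinearImpl>`, and only `LinearImpl` derives from `torch::nn::Module`. The names below (`Module`, `LinearImpl`, `Holder`, `Linear`, `matches`) are hypothetical stand-ins that mimic that layout with standard C++ only.

```cpp
#include <memory>

// Hypothetical stand-ins: only the Impl class is part of the Module hierarchy.
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module {};

// Plays the role of the wrapper type (e.g. torch::nn::Linear):
// it holds a shared_ptr to the Impl but is not itself a Module.
template <typename Impl>
struct Holder
{
    std::shared_ptr<Impl> impl = std::make_shared<Impl>();
};
using Linear = Holder<LinearImpl>;

// True if the module can be viewed as T via dynamic_pointer_cast.
template <typename T>
bool matches(const std::shared_ptr<Module>& m)
{
    return std::dynamic_pointer_cast<T>(m) != nullptr;
}
```

With this layout, `matches<LinearImpl>(m)` is true for a linear module, while `matches<Linear>(m)` is always false because the wrapper is unrelated to `Module`.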