I’m recreating the python method from “Deep Learning with Pytorch” p. 295, shown below.
I’m curious about the recommended c++ method for duplicating “type(m)”. As far as I can tell, given an AnyModule, the only way I can determine the underlying type is to attempt a cast via get<>. My c++ code is below.
Is there any way to determine the type wo suffering a throw/catch?
def _init_weights(self):
for m in self.modules():
if type(m) in {
nn.Linear,
nn.Conv3d,
}:
nn.init.kaiming_normal_(
m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
)
if m.bias is not None:
fan_in, fan_out = \
nn.init._calculate_fan_in_and_fan_out(m.weight.data)
bound = 1 / math.sqrt(fan_out)
nn.init.normal_(m.bias, -bound, bound)
for (auto it = m_Model->begin(); it != m_Model->end(); ++it)
{
// Can't find any way to check type of AnyModule except get<> with potential throw.
// https://pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_any_module.html
// type_info() deleted function.
// auto info = it->type_info();
try
{
// throws: "Attempted to cast module of type class torch::nn::Conv1dImpl to type class torch::nn::LinearImpl"
auto m = it->get<torch::nn::Linear>();
torch::nn::init::kaiming_normal_(m->weight.data(), 0.0,torch::kFanOut, torch::kTanh);
if (m->bias.numel())
{
auto Fan = torch::nn::init::_calculate_fan_in_and_fan_out(m->weight.data());
auto FanOut = std::get<1>(Fan);
auto Bound = 1.0/sqrt(FanOut);
torch::nn::init::normal_(m->bias, -Bound, Bound);
}
continue;
}
catch (...)
{
}