I tried implementing the approach described here: https://stackoverflow.com/questions/49433936/how-to-initialize-weights-in-pytorch
i.e. creating a function that can be passed to apply() so it runs on every module that is part of the sequential net. Here is the code:
=============================================
void Init_Weights(torch::nn::Module& m)
{
    if (typeid(m) == typeid(torch::nn::Linear))
    {
        torch::nn::init::xavier_normal_(m->weight);
        torch::nn::init::constant_(m->bias, 0.01);
    }
}

int main() {
    torch::nn::Sequential XORModel(
        torch::nn::Linear(2, 3),
        torch::nn::Functional(torch::tanh),
        torch::nn::Linear(3, 1),
        torch::nn::Functional(torch::sigmoid));

    XORModel->apply(Init_Weights);
}
=================================================
The problem is that this code will not compile. apply() requires Init_Weights to take a torch::nn::Module& as its parameter, but the base type Module has no weight member, so m->weight cannot be resolved.
If I instead change the signature of Init_Weights to take a torch::nn::Linear&, then it can no longer be passed to apply().
Is there a way to apply Xavier normal initialization to the weights of all Linear modules that are part of a Sequential module?
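In case it helps, this is the shape of callback I am hoping exists — a sketch only, where the m.as&lt;torch::nn::Linear&gt;() downcast is my guess at an API that would let the function keep the torch::nn::Module& signature while still reaching Linear-specific members:

```cpp
#include <torch/torch.h>

// Sketch: runs on every submodule. The as<>() call is assumed to
// return a pointer to the Linear implementation when the module is
// a Linear, and nullptr otherwise, so non-Linear modules are skipped.
void Init_Weights(torch::nn::Module& m)
{
    if (auto* linear = m.as<torch::nn::Linear>())
    {
        torch::nn::init::xavier_normal_(linear->weight);
        torch::nn::init::constant_(linear->bias, 0.01);
    }
}
```

With something like this, XORModel->apply(Init_Weights); would compile unchanged, since the callback still accepts a plain torch::nn::Module&.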