The problem is that the code will not compile. Torch requires the definition of Init_Weights to take ‘torch::nn::Module& m’ as input, but then ‘m->weight’ cannot be resolved, since the type Module has no ‘weight’ member.
If I change the definition of Init_Weights so that its input is of type ‘torch::nn::Linear& m’, then Init_Weights can no longer be passed to apply.
Is there a way to initialize (Xavier normal) all weights in the Linear modules that are part of a Sequential module?
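For concreteness, the setup looks roughly like this (the layer sizes and the model name are made up for illustration):

torch::nn::Sequential net(
    torch::nn::Linear(16, 32),
    torch::nn::ReLU(),
    torch::nn::Linear(32, 1));
// apply expects a callable taking torch::nn::Module&,
// which is why a torch::nn::Linear& overload is not accepted:
net->apply(Init_Weights);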
Just to update - I found a solution. See code below:
#include <torch/torch.h>

void Init_Weights(torch::nn::Module& m)
{
    // typeid(m) is the dynamic type of the module; for a Linear layer that
    // is torch::nn::LinearImpl (torch::nn::Linear is only the holder type).
    if ((typeid(m) == typeid(torch::nn::LinearImpl)) || (typeid(m) == typeid(torch::nn::Linear))) {
        auto p = m.named_parameters(false);
        auto w = p.find("weight");
        auto b = p.find("bias");
        if (w != nullptr) torch::nn::init::xavier_uniform_(*w);
        if (b != nullptr) torch::nn::init::constant_(*b, 0.01);
    }
}
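With this in place, every Linear submodule can be initialized in a single call, e.g. with a Sequential named net (the name is illustrative):

net->apply(Init_Weights);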
In case anyone runs across this particular question again, the following should work as a simple solution:
void xavier_init(torch::nn::Module& module) {
    torch::NoGradGuard noGrad;
    // as<torch::nn::Linear>() returns a LinearImpl* if the module is a
    // Linear layer, and nullptr otherwise.
    if (auto* linear = module.as<torch::nn::Linear>()) {
        torch::nn::init::xavier_normal_(linear->weight);
        torch::nn::init::constant_(linear->bias, 0.01);
    }
}
Then for any Linear module or module with Linear submodules, you can just initialize via module->apply(xavier_init). I think this is basically what @cheggars was suggesting with their second response.
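A minimal usage sketch (the layer sizes are made up), showing that apply also recurses into nested containers:

torch::nn::Sequential net(
    torch::nn::Linear(4, 8),
    torch::nn::ReLU(),
    torch::nn::Sequential(torch::nn::Linear(8, 8), torch::nn::ReLU()),
    torch::nn::Linear(8, 1));
net->apply(xavier_init);  // initializes the Linear layers in the inner Sequential too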
The implementations in torch.nn.init also rely on no-grad mode when initializing the parameters, so as to avoid autograd tracking when updating the initialized parameters in-place.
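To see why the guard matters, here is a small self-contained sketch (the tensor is just an example): an in-place update of a leaf tensor that requires grad throws outside no-grad mode, but is allowed, and untracked, inside it:

#include <torch/torch.h>

int main() {
    auto w = torch::randn({3, 3}, torch::requires_grad());
    {
        torch::NoGradGuard no_grad;
        w.fill_(0.5);  // fine: the in-place update is not recorded by autograd
    }
    // w.fill_(0.5);  // error: in-place operation on a leaf tensor that requires grad
}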