Init parameters - weights_init not defined


(Fabrice noreils) #1

Dear All,

After reading different threads, I implemented what seems to be the “standard” way to initialize the parameters of all layers (see code below):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.conv3 = nn.Conv2d(64, 128, 5)
        self.conv4 = nn.Conv2d(128, 256, 3)
        self.conv5 = nn.Conv2d(256, 256, 2)
        self.fc1 = nn.Linear(6400, 6400)
        self.fc2 = nn.Linear(6400, 6400)
        self.fc3 = nn.Linear(6400, 136)

    def weights_init(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(m.weight)
            # xavier_normal_ needs a tensor with at least 2 dims, so the 1-D bias is zeroed instead
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    ## feedforward behavior
    def forward(self, x):
        # check whether tanh is preferable here?
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # training=self.training switches dropout off after net.eval()
        x = F.dropout(x, p=0.1, training=self.training)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(F.max_pool2d(self.conv4(x), 2))
        x = F.dropout(x, p=0.4, training=self.training)
        x = F.relu(F.avg_pool2d(self.conv5(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.6, training=self.training)
        y_pred = torch.tanh(self.fc3(x))
        return y_pred
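The hard-coded 6400 in fc1 is the flattened output of conv5 plus its pooling (256 channels x 5 x 5), which only matches one input size. A small pure-Python shape check, assuming a single-channel 224x224 input (the post does not state the input size); the helper names are hypothetical:

```python
# Assumption: single-channel 224x224 input (the input size is not given in the post).

def conv_out(size, kernel):
    """Spatial size after an unpadded, stride-1 convolution."""
    return size - kernel + 1


def pool_out(size, window=2):
    """Spatial size after F.max_pool2d / F.avg_pool2d with kernel 2 and default stride."""
    return size // window


def flattened_features(size, layers):
    """Trace conv -> pool for each (out_channels, kernel) pair; return C * H * W."""
    channels = None
    for channels, kernel in layers:
        size = pool_out(conv_out(size, kernel))
    return channels * size * size


# (out_channels, kernel_size) for conv1..conv5, copied from Net.__init__
NET_CONVS = [(32, 5), (64, 5), (128, 5), (256, 3), (256, 2)]

print(flattened_features(224, NET_CONVS))  # 6400
```

The spatial size goes 224 -> 110 -> 53 -> 24 -> 11 -> 5, so a different input size generally changes the flattened size and fc1 would then fail with a shape mismatch.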

But when I call:
net.apply(weights_init)

I get the following error:

NameError                                 Traceback (most recent call last)
in <module>()
     98 net = Net()
     99 #BN_net.weights_init()
---> 100 net.apply(weights_init)
    101 print (net)

NameError: name 'weights_init' is not defined

Can someone tell me what is wrong here?

Thank you very much


(Juan F Montesinos) #2

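weights_init is defined inside the class, and you are trying (I think; you posted no code for that part) to call it from outside the class.
Since it has no self parameter, you would have to call it through the class:
net.apply(Net.weights_init)
(net.weights_init would be a bound method and receive net as an extra argument.) But it makes no sense to define it inside the class.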
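The scoping rule behind this answer can be checked without PyTorch. In the minimal sketch below, a hypothetical class Holder stands in for Net: a def in a class body becomes a class attribute, not a module-level name, and because it has no self parameter it only works when looked up on the class.

```python
class Holder:
    # Hypothetical stand-in for Net: a function defined in the class body
    # with no `self` parameter, like weights_init in the question.
    def helper(x):
        return x * 2


# 1) The name is not visible at module level, hence the NameError.
try:
    helper(3)
except NameError as err:
    print(err)  # name 'helper' is not defined

# 2) Looked up on the class, it is a plain function and can be called directly.
print(Holder.helper(3))  # 6

# 3) Looked up on an instance, it becomes a bound method: the instance fills
#    `x`, so passing another argument raises TypeError.
try:
    Holder().helper(3)
except TypeError as err:
    print(type(err).__name__)  # TypeError
```

Moving the function to module level avoids both problems, which is why that is the usual pattern for functions passed to Module.apply.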


(Fabrice noreils) #3

@JuanFMontesinos Ah, I see: this function should be defined outside of the class.
And yes, it works.
Thank you!
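For reference, a minimal sketch of the module-level version, assuming a recent PyTorch release. Biases are zero-filled because nn.init.xavier_normal_ raises a ValueError for tensors with fewer than two dimensions, and nn.Linear is covered as well. The zero-bias choice and the small stand-in model are illustrative assumptions, not something stated in the thread.

```python
import torch
import torch.nn as nn


def weights_init(m):
    """Initialize a single module; meant to be passed to Module.apply."""
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Xavier init needs a tensor with at least 2 dims, so use it only for weights.
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)


# Hypothetical small model standing in for Net, just to show the call.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3),        # 1x8x8 input -> 8x6x6
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)

# apply() calls weights_init on every submodule and then on the model itself.
model.apply(weights_init)

out = model(torch.randn(2, 1, 8, 8))
print(out.shape)  # torch.Size([2, 10])
```

Because weights_init is an ordinary module-level function, the same net.apply(weights_init) call from the question works unchanged.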