Weight initialisation

First, define a function that checks each module's class name and applies the initialisation only to the layers that match.

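def weights_init(m):
    classname = m.__class__.__name__
    # Matches Conv1d, Conv2d, Conv3d, ConvTranspose2d, etc.
    if 'Conv' in classname:
        xavier(m.weight.data)
        # Xavier init needs a tensor with at least 2 dims, so the 1-D bias is zeroed instead
        if m.bias is not None:
            m.bias.data.zero_()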

Then apply it to the whole network. nn.Module.apply() calls the function recursively on every submodule and finally on the network itself.

net = Net()  # create an instance of the Net class
net.apply(weights_init)  # run weights_init on every submodule
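To see which modules net.apply() actually visits, here is a small sketch that records each class name instead of initialising anything. The Net class below is a hypothetical stand-in for the post's network (layer sizes are arbitrary), and record() is a hypothetical helper used only for illustration.

```python
import torch.nn as nn

# Hypothetical stand-in for the post's Net class
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.block = nn.Sequential(nn.Conv2d(8, 16, kernel_size=3), nn.ReLU())
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.block(self.conv1(x))
        x = x.mean(dim=(2, 3))  # global average pool -> shape (N, 16)
        return self.fc(x)

visited = []

def record(m):
    # apply() calls this once per module, children before their parent
    visited.append(m.__class__.__name__)

Net().apply(record)
print(visited)  # ['Conv2d', 'Conv2d', 'ReLU', 'Sequential', 'Linear', 'Net']
```

Because the traversal also reaches modules nested inside containers such as nn.Sequential, a name check like 'Conv' in classname is enough to hit every convolution, however deeply it is nested.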

That's it. The only remaining step is to define the xavier() function yourself, for example as a thin wrapper around torch.nn.init.xavier_uniform_.
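A minimal end-to-end sketch, assuming xavier() is meant to be uniform Xavier/Glorot initialisation as provided by torch.nn.init.xavier_uniform_ (xavier_normal_ would be an equally valid choice). The Net class is hypothetical, and the final check only confirms that conv1's weights fall within the Xavier-uniform bound sqrt(6 / (fan_in + fan_out)) and that its bias was zeroed.

```python
import math
import torch
import torch.nn as nn

def xavier(tensor):
    # Assumed choice: uniform Xavier/Glorot init, applied in place
    nn.init.xavier_uniform_(tensor)

def weights_init(m):
    classname = m.__class__.__name__
    if 'Conv' in classname:
        xavier(m.weight.data)
        # Xavier needs >= 2 dims, so the 1-D bias is zeroed instead
        if m.bias is not None:
            m.bias.data.zero_()

class Net(nn.Module):  # hypothetical example network
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.relu(self.conv2(torch.relu(self.conv1(x))))
        return self.fc(x.mean(dim=(2, 3)))

torch.manual_seed(0)
net = Net()
net.apply(weights_init)

# For conv1: fan_in = 3 * 3 * 3, fan_out = 8 * 3 * 3 (channels times kernel area)
bound = math.sqrt(6.0 / (3 * 9 + 8 * 9))
print(net.conv1.weight.abs().max().item() <= bound + 1e-6)  # True (small float32 tolerance)
print(net.conv1.bias.abs().sum().item())  # 0.0
```

The Linear layer is left at PyTorch's default initialisation because its class name does not contain 'Conv'; add another branch to weights_init if it should be covered too.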
