First, define a function that checks each module's class name and applies the initialisation only to the layers that match (here, convolutional layers).
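    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:  # matches Conv1d/2d/3d and ConvTranspose*
            xavier(m.weight.data)
            # Xavier needs a tensor with at least 2 dims, so it cannot be applied to
            # the 1-D bias; zero it instead (layers built with bias=False have none).
            if m.bias is not None:
                m.bias.data.zero_()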
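Then apply it to every module in the network with apply(), which calls the function on each submodule recursively and finally on the network itself: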
    net = Net()              # create an instance of your Net class
    net.apply(weights_init)  # call weights_init on every submodule
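To see what the traversal actually does, the small sketch below uses a throwaway nn.Sequential as a stand-in for Net (an assumption for illustration only) and records the class name of every module that apply() visits. Children are visited before their parent, and the top-level module comes last, so the name check sees every layer, including those nested inside containers.

```python
import torch.nn as nn

# Stand-in network with a nested container, used only to observe apply()'s order.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 16, kernel_size=3), nn.ReLU()),
)

visited = []
net.apply(lambda m: visited.append(m.__class__.__name__))
print(visited)  # ['Conv2d', 'ReLU', 'Conv2d', 'ReLU', 'Sequential', 'Sequential']
```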
That's it. All that remains is to define the xavier() function.
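One possible xavier(), assuming Xavier/Glorot uniform initialisation is what is wanted, is a thin wrapper around PyTorch's built-in torch.nn.init.xavier_uniform_. The sketch below puts the pieces together with a hypothetical Net (two conv layers and one linear layer, not the original poster's model). Because Xavier initialisation needs at least two dimensions, weights_init here zeroes the conv biases rather than passing them to xavier(). The final lines check that conv1's weights fall inside the Xavier-uniform bound sqrt(6 / (fan_in + fan_out)) and that its bias is zero.

```python
import torch
import torch.nn as nn


def xavier(tensor):
    # Hypothetical helper: fill a weight tensor in place with Xavier/Glorot
    # uniform values. Only valid for tensors with >= 2 dims (not biases).
    nn.init.xavier_uniform_(tensor)


def weights_init(m):
    classname = m.__class__.__name__
    if 'Conv' in classname:  # same name check as above
        xavier(m.weight.data)
        if m.bias is not None:
            m.bias.data.zero_()


class Net(nn.Module):
    # Hypothetical stand-in for the question's Net class.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.fc = nn.Linear(16, 10)  # 'Linear' does not match, keeps default init

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.mean(dim=(2, 3))  # global average pool -> (N, 16)
        return self.fc(x)


net = Net()
net.apply(weights_init)

# conv1 weight shape is (8, 3, 3, 3): fan_in = 3*3*3 = 27, fan_out = 8*3*3 = 72.
bound = (6.0 / (27 + 72)) ** 0.5
print(net.conv1.weight.abs().max().item() <= bound + 1e-6)  # True
print(net.conv1.bias.abs().sum().item())  # 0.0
```

Using the built-in initialiser avoids re-deriving the fan-in/fan-out arithmetic by hand; torch.nn.init also provides xavier_normal_ if a normal rather than uniform distribution is preferred.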