Question about net initialization

Hello. I want to ask a question about net initialization.
Usually the net's __init__ function contains a loop like this:

for m in self.modules():
    print(m)
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight)  # default mode is 'fan_in'
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)   # BN scale (gamma) = 1
        nn.init.zeros_(m.bias)    # BN shift (beta) = 0
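A self-contained sketch of the same loop, written as a standalone function and run on a hypothetical toy model with one nested nn.Sequential, shows that modules() does reach layers inside nested submodules. It uses the in-place nn.init functions (kaiming_normal_, zeros_, ones_) instead of the deprecated kaiming_normal and .data access.

```python
import torch
import torch.nn as nn

def init_weights(model: nn.Module) -> None:
    # model.modules() yields the model itself and every submodule, recursively
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight)  # default mode='fan_in'
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

# Hypothetical toy model: the inner Conv2d and BatchNorm2d sit inside a nested Sequential
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)),
)
init_weights(model)
```

The nested conv's bias ends up all zeros even though it is not a direct child of model, which confirms the recursion.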

It seems that self.modules() iterates recursively over all submodules. If my net contains a subnet that also has batchnorm layers, but I want to initialize those differently, what should I do?
Concretely, one part of the net is VGG-like with batchnorm, where the BN scale (weight) is usually initialized to 1.
However, another part of my net uses residual connections like ResNet, where the scale of the last BN in each residual block is often initialized to 0. How should I handle this situation?
Thanks!

You can store the different parts as separate attributes (submodules) of your class and initialize each one separately, e.g.:

class Net(nn.Module):
  def __init__(self):
    super().__init__()
    ...
    self.vgg = ...     # VGG-style part
    self.resnet = ...  # ResNet-style part

  def init_weights(self):
    # do something with self.vgg.modules()
    # do something else with self.resnet.modules()
    ...
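A fuller sketch of this suggestion follows. It assumes a hypothetical BasicBlock whose last BatchNorm is named bn2 (all class and attribute names here are illustrative, not from the original post). The idea: run the generic loop over the whole net first, then override only the modules under self.resnet. This is similar in spirit to torchvision's zero_init_residual option for ResNet, but it is not that library's code.

```python
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    # Hypothetical residual block: conv-bn-relu-conv-bn plus identity shortcut
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)  # last BN in the residual branch
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # VGG-style part: plain conv-BN-ReLU stack
        self.vgg = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(inplace=True),
        )
        # ResNet-style part: stack of residual blocks
        self.resnet = nn.Sequential(BasicBlock(16), BasicBlock(16))
        self.init_weights()

    def forward(self, x):
        return self.resnet(self.vgg(x))

    def init_weights(self):
        # 1) Generic init over the whole net (recurses into both parts)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        # 2) Override only the ResNet part: zero the scale of the last BN
        #    in each residual branch, so each branch initially outputs zero
        for m in self.resnet.modules():
            if isinstance(m, BasicBlock):
                nn.init.zeros_(m.bn2.weight)

net = Net()
y = net(torch.randn(2, 3, 8, 8))
print(y.shape)  # torch.Size([2, 16, 8, 8])
```

Because step 2 runs after step 1, the ResNet-specific setting wins only where you target it, and the VGG part keeps its BN scale at 1.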