How to do weights initialization in nn.ModuleList?


(Tao Jiang) #1

I want to initialize the weights in an nn.ModuleList instance.
Here are my codes:

class ConvBlock(nn.Module):
    """A stack of ``LeakyReLU -> Conv2d`` blocks.

    The first block maps ``input_channels -> num_filters``; every later
    block maps ``num_filters -> num_filters``. ``forward`` returns the
    input followed by each block's output, in order.
    """

    def __init__(self, input_channels, num_filters=128, conv_nums=3):
        super(ConvBlock, self).__init__()
        # Per-block input channel counts: first block consumes the raw
        # input, all subsequent blocks consume num_filters channels.
        in_channels = [input_channels] + [num_filters] * (conv_nums - 1)
        self.conv_blocks = nn.ModuleList(
            nn.Sequential(
                nn.LeakyReLU(0.1),
                nn.Conv2d(c_in, num_filters, kernel_size=(3, 3), padding=1),
            )
            for c_in in in_channels
        )

    def forward(self, input_tensor):
        """Run every block in sequence and collect all intermediates."""
        outputs = [input_tensor]
        current = input_tensor
        for block in self.conv_blocks:
            current = block(current)
            outputs.append(current)
        return outputs

and the weight initialization code I usually use is

# He (Kaiming) normal initialization for every Conv2d submodule.
# The manual version drew from N(0, sqrt(2 / n)) with
# n = k_h * k_w * out_channels, i.e. fan_out; nn.init.kaiming_normal_
# with mode='fan_out' computes exactly the same std, but does it under
# torch.no_grad() instead of poking m.weight.data directly (the .data
# idiom bypasses autograd bookkeeping and is discouraged).
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

but it does not seem to work for a complicated network structure. Could someone tell me how to solve this problem?


(Kaushal Paneri) #2

You can use apply() to initialize weights. Here is the discussion thread.

class ConvBlock(nn.Module):
    """A stack of ``LeakyReLU -> Conv2d`` blocks with Xavier-initialized
    weights.

    Fixes over the snippet as posted:
    - the original mixed tabs and spaces (the ``apply`` call and the
      ``_init_weights`` def were space-indented inside a tab-indented
      class), which raises an IndentationError/TabError before anything
      runs; indentation is now uniform;
    - ``_init_weights`` guards against layers built with ``bias=False``,
      where ``m.bias`` is ``None`` and ``fill_`` would crash;
    - bias init uses ``nn.init.constant_`` instead of mutating ``.data``.
    """

    def __init__(self, input_channels, num_filters=128, conv_nums=3):
        super(ConvBlock, self).__init__()
        conv_blocks = [nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Conv2d(input_channels, num_filters, kernel_size=(3, 3), padding=1))]
        for _ in range(conv_nums - 1):
            conv_blocks.append(nn.Sequential(
                nn.LeakyReLU(0.1),
                nn.Conv2d(num_filters, num_filters, kernel_size=(3, 3), padding=1)))
        self.conv_blocks = nn.ModuleList(conv_blocks)
        # apply() walks every submodule recursively (including the ones
        # nested inside nn.Sequential) and calls the hook on each, so
        # the ModuleList wrapper is no obstacle to initialization.
        self.conv_blocks.apply(self._init_weights)

    def forward(self, input_tensor):
        """Run every block in sequence; return input plus all intermediates."""
        out = [input_tensor]
        for conv in self.conv_blocks:
            input_tensor = conv(input_tensor)
            out.append(input_tensor)
        return out

    def _init_weights(self, m):
        """Xavier-normal init for Conv2d/Linear weights, constant 0.01 bias."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:  # layers may be built with bias=False
                nn.init.constant_(m.bias, 0.01)