Weights init for new classes

I’m working on a CNN and experimenting with some modifications to it. To illustrate my issue, I will briefly explain what I call the default form and the modified form.

The default form uses mostly built-in PyTorch classes, such as Conv2d, BatchNorm2d, and so on. In the modified form, I intend to experiment with using convolutions in a different way, which I will call parallel convolutions. Here I apply dilated convolutions (dilation values of 1, 2 and 3) to the same input, stack their outputs, and process the result through a final ordinary convolution (dilation=1). To do so, I created a new class that handles all of these convolutions, making it simpler to call afterwards.

The problem is: when I use the default form, it’s easy to initialize weights with something like this

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

but I can’t get it to work with my new class. I need some way to initialize the convolutions which I built into this new class of mine. Do you have any insights for me?

The following code shows the class I created and it is intended to be used instead of nn.Conv2d in my CNN:

class MSConv2d(nn.Module):
  def __init__(self, in_channels, num_filters, kernel_size, use_bias, stride=1, padding='zero'):
    """Builds a multi-scale convolution block (https://link)
    """
    super(MSConv2d, self).__init__()
    dilation1 = [get_same_padding_layer(kernel_size, stride=1, mode=padding),
                 nn.Conv2d(in_channels, num_filters, kernel_size=kernel_size, stride=1, dilation=1, bias=use_bias)]
    dilation2 = [get_same_padding_layer(kernel_size, stride=1, mode=padding),
                 nn.Conv2d(in_channels, num_filters, kernel_size=kernel_size, stride=1, dilation=2, bias=use_bias)]
    dilation3 = [get_same_padding_layer(kernel_size, stride=1, mode=padding),
                 nn.Conv2d(in_channels, num_filters, kernel_size=kernel_size, stride=1, dilation=3, bias=use_bias)]
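    # The final convolution sees the three branch outputs concatenated along
    # the channel axis, so it takes 3 * num_filters input channels.
    final_conv = [get_same_padding_layer(kernel_size, stride=1, mode=padding),
                  nn.Conv2d(3 * num_filters, num_filters, kernel_size=kernel_size, stride=1, bias=use_bias)]
    self.dilationBlock1 = nn.Sequential(*dilation1)
    self.dilationBlock2 = nn.Sequential(*dilation2)
    self.dilationBlock3 = nn.Sequential(*dilation3)
    self.convBlock = nn.Sequential(*final_conv)

  def forward(self, x):
    output1 = self.dilationBlock1(x)
    output2 = self.dilationBlock2(x)
    output3 = self.dilationBlock3(x)
    # torch.cat joins the branches along the channel dimension; torch.stack would
    # add a new dimension and produce a 5-D tensor that nn.Conv2d cannot process.
    output = torch.cat((output1, output2, output3), dim=1)
    output = self.convBlock(output)
    return output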

Hello,

I think you could iterate over model.named_modules(). It yields every sub-module in the network recursively, including those nested inside nn.Sequential containers, so the Conv2d layers inside your custom class are reachable too.

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

        dilation1 = [nn.Conv2d(3,3,3, padding=1, bias=False),
                     nn.Conv2d(3,3,3, padding=1, bias=False)]
        self.conv1 = nn.Sequential(*dilation1)

    def forward(self, input):

        output = self.conv1(input)
        return output

model = MyModule()
for module in model.named_modules():
    print(module)
# output
('', MyModule(
  (conv1): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
))
('conv1', Sequential(
  (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
))
('conv1.0', Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))
('conv1.1', Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))
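
To connect this back to the original question, here is a minimal sketch of how the initialization could look once the nested layers are reachable. It assumes weights_init is applied with model.apply and that the name-based check fails because the substring 'Conv' also appears in a custom class name such as MSConv2d, which is a container with no .weight of its own. ParallelDilatedConv is a hypothetical stand-in for MSConv2d that uses the built-in padding argument instead of get_same_padding_layer, and the isinstance checks are a swapped-in alternative to the string matching, not something taken from the original post.

```python
import torch
import torch.nn as nn


class ParallelDilatedConv(nn.Module):
    """Hypothetical stand-in for MSConv2d (assumption: odd kernel size, stride 1)."""

    def __init__(self, in_channels, num_filters, kernel_size=3, use_bias=True):
        super().__init__()
        # With an odd kernel and stride 1, padding = dilation * (kernel_size // 2)
        # keeps the spatial size unchanged for every branch.
        self.branches = nn.ModuleList([
            nn.Conv2d(in_channels, num_filters, kernel_size,
                      padding=d * (kernel_size // 2), dilation=d, bias=use_bias)
            for d in (1, 2, 3)
        ])
        # The final convolution consumes the three branches concatenated on channels.
        self.final_conv = nn.Conv2d(3 * num_filters, num_filters, kernel_size,
                                    padding=kernel_size // 2, bias=use_bias)

    def forward(self, x):
        out = torch.cat([branch(x) for branch in self.branches], dim=1)
        return self.final_conv(out)


def weights_init(m):
    # Type checks only match real layers, so containers such as
    # ParallelDilatedConv, nn.ModuleList or nn.Sequential are skipped.
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)


model = nn.Sequential(ParallelDilatedConv(3, 8), nn.BatchNorm2d(8), nn.ReLU())
model.apply(weights_init)  # visits every nested sub-module, then the model itself
```

model.apply(fn) calls fn on each sub-module recursively, so looping over model.named_modules() and calling weights_init(module) on each entry should have the same effect while also giving you the names for logging.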

Thanks for the input, MariosOreo, and I’m sorry for taking so long to answer. I had implemented a workaround, but I’ll try your suggestion soon, as it is more sophisticated and should generally work better for other models.