Initializing weights of a custom Conv layer module

Hey all,

I have the following custom convolutional module, whose weights I initialize using nn.Parameter:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(DilatedConv, self).__init__()
        # Initialize kernel
        self.kernel = torch.randn(out_channels, in_channels,  kernel_size, kernel_size)
        # Register the trainable parameters
        self.weight = nn.Parameter(self.kernel)
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x, stride, padding, dilation):
        # Do a functional call so we can 
        # use the same weights but different arguments
        return F.conv2d(x, self.weight, bias=self.bias, 
                    stride=stride, padding=padding, dilation=dilation
                )

class SDCLayer(nn.Module):
    def __init__(self, input_size, n_conv, kernel_sizes, n_kernels, dilations):
        super(SDCLayer, self).__init__()
        self.input_size = input_size
        self.n_conv = n_conv
        self.kernel_sizes = kernel_sizes
        self.n_kernels = n_kernels
        self.dilations = dilations

        self.dilated_conv = DilatedConv(self.input_size, self.n_kernels, self.kernel_sizes)
        self.elu = nn.ELU()


    def weights_init_normal(self):
        pass


    def forward(self, x):

        # The parallel convolutions share weights:
        # the same kernel is applied once per dilation
        # and the results are concatenated along channels
        outputs = []
        for i in range(self.n_conv):
            x_d = self.dilated_conv(x, stride=1, padding='same', dilation=self.dilations[i])
            outputs.append(x_d)
        sdc = torch.cat(outputs, dim=1)

        sdc = self.elu(sdc)

        return sdc
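A minimal sketch of what one SDC block computes, assuming stride 1 and padding='same' (available in PyTorch 1.9 and later): the same kernel is applied once per dilation and the results are concatenated along the channel dimension, so the block outputs n_conv * out_channels channels at the input's spatial size. The sizes are arbitrary example values, and collecting the outputs in a list before a single torch.cat is a stylistic choice, not the original code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

in_channels, out_channels, kernel_size = 3, 8, 3
dilations = [1, 2, 3, 4]  # one parallel convolution per dilation

# A single kernel and bias shared by all parallel convolutions
weight = torch.randn(out_channels, in_channels, kernel_size, kernel_size)
bias = torch.zeros(out_channels)

x = torch.randn(2, in_channels, 32, 32)

# Apply the same weights with each dilation; 'same' padding keeps H and W
outputs = [
    F.conv2d(x, weight, bias=bias, stride=1, padding='same', dilation=d)
    for d in dilations
]

# Concatenate once along channels: 4 dilations x 8 channels = 32 channels
sdc = torch.cat(outputs, dim=1)
print(sdc.shape)  # torch.Size([2, 32, 32, 32])
```

Because the result lives on whatever device x and weight are on, this pattern also avoids hardcoding "cuda" or "cpu" for an empty accumulator tensor.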

As you can see, I'm initializing the weights with torch.randn, and I suspect this is causing problems, because my model ends up not learning.

I've read that the problem could be wrong weight initialization, so I'd like to initialize the custom conv layer correctly. How can I do this?

Thank you!

torch.nn.init provides most of the commonly used initialization methods.

For your case, try this (it mirrors the default initialization of nn.Conv2d):

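import math  # needed for math.sqrt

# Inside DilatedConv.__init__, after self.weight and self.bias are created.
# This assumes self.in_channels and self.kernel_size are stored as attributes.
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

# Bias: uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)]
fan_in = self.in_channels * self.kernel_size * self.kernel_size
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)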
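For reference, a self-contained sketch of DilatedConv with the suggested initialization moved into a reset_parameters method. The method name follows the convention nn.Conv2d uses, but putting it here is an assumption, not part of the original code. With a=sqrt(5), the Kaiming-uniform gain is sqrt(2/6) = sqrt(1/3), so the weight bound sqrt(3) * gain / sqrt(fan_in) also works out to 1/sqrt(fan_in), the same bound as the bias.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class DilatedConv(nn.Module):
    """One kernel that can be applied with different stride/padding/dilation."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        # Uninitialized storage; reset_parameters() fills it in
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        # Same scheme as nn.Conv2d's default initialization
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in = self.in_channels * self.kernel_size * self.kernel_size
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, stride=1, padding='same', dilation=1):
        return F.conv2d(x, self.weight, bias=self.bias, stride=stride,
                        padding=padding, dilation=dilation)


torch.manual_seed(0)
conv = DilatedConv(3, 8, 3)
bound = 1 / math.sqrt(3 * 3 * 3)  # fan_in = in_channels * k * k = 27
print(conv.weight.abs().max().item(), conv.bias.abs().max().item(), bound)

# With dilation=1 it behaves like an nn.Conv2d holding the same parameters
ref = nn.Conv2d(3, 8, 3, padding='same')
with torch.no_grad():
    ref.weight.copy_(conv.weight)
    ref.bias.copy_(conv.bias)
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(conv(x), ref(x), atol=1e-6))
```

The comparison with nn.Conv2d is a quick way to confirm the functional call is wired correctly before looking elsewhere for training problems.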


Thanks for your reply! I added these lines to my code and they seem to work fine. However, the problem I mentioned still happens: my network is not learning (or it learns to predict just one of the three classes). Do you have any idea what could cause this? I checked my training code with a simpler network and it works well, so the problem apparently lies in this custom convolutional network.

This is the full code of the network that I’m using:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(DilatedConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Initialize kernel
        self.kernel = torch.randn(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        # Register the trainable parameters
        self.weight = nn.Parameter(self.kernel)
        self.bias = nn.Parameter(torch.randn(out_channels))
        # Initialize the weights
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # Bias
        fan_in = self.in_channels * self.kernel_size * self.kernel_size
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, stride, padding, dilation):
        # Do a functional call so we can 
        # use the same weights but different arguments
        return F.conv2d(x, self.weight, bias=self.bias, 
                    stride=stride, padding=padding, dilation=dilation
                )

class SDCLayer(nn.Module):
    def __init__(self, input_size, n_conv, kernel_sizes, n_kernels, dilations):
        super(SDCLayer, self).__init__()
        self.input_size = input_size
        self.n_conv = n_conv
        self.kernel_sizes = kernel_sizes
        self.n_kernels = n_kernels
        self.dilations = dilations

        self.dilated_conv = DilatedConv(self.input_size, self.n_kernels, self.kernel_sizes)
        self.elu = nn.ELU()


    def forward(self, x):

        # The parallel convolutions share weights:
        # the same kernel is applied once per dilation
        # and the results are concatenated along channels
        outputs = []
        for i in range(self.n_conv):
            x_d = self.dilated_conv(x, stride=1, padding='same', dilation=self.dilations[i])
            outputs.append(x_d)
        sdc = torch.cat(outputs, dim=1)

        sdc = self.elu(sdc)

        return sdc
        

class SDCNetwork(nn.Module):
    def __init__(self, num_layers, input_size, n_conv, kernel_sizes, n_kernels, dilations):
        super(SDCNetwork, self).__init__()
        self.input_size = input_size
        self.n_conv = n_conv
        self.kernel_sizes = kernel_sizes
        self.n_kernels = n_kernels
        self.dilations = dilations
        self.num_classes = 3
        
        self.features = nn.Sequential()
        # Iterate through the number of layers
        for i in range(num_layers):
            self.features.add_module(
                'sdc'+str(i),
                SDCLayer(input_size=self.input_size[i],
                            n_conv=self.n_conv,
                            kernel_sizes=self.kernel_sizes,
                            n_kernels=self.n_kernels[i],
                            dilations=self.dilations
                ) 
            )

        self.conv_head = nn.Conv2d(self.n_kernels[-1]*self.n_conv, 512, kernel_size=1)

        # Class activation map layer similarly to ACOL paper output
        self.cam = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding='same'),
            nn.SiLU(),
            nn.Conv2d(256, self.num_classes, kernel_size=1, padding='same'),
            nn.SiLU()
        )

        # Global average pooling
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)


    def forward(self, x1, x2):
        x1, x2 = self.features(x1), self.features(x2)

        x = torch.cat((x1, x2), dim=2)

        x = self.conv_head(x)

        x = self.cam(x)

        x = self.gap(x)

        x = torch.flatten(x, 1)

        return x

As you can see, I'm stacking N dilated convolution blocks (inside each block, the same weights are applied with different dilations). After that, I compute the CAM + GAP to classify into 3 classes.
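Since each SDC block concatenates n_conv outputs of n_kernels[i] channels, block i emits n_kernels[i] * n_conv channels, so input_size[i + 1] has to equal n_kernels[i] * n_conv, and conv_head receives n_kernels[-1] * n_conv channels. The two branches are concatenated along dim=2 (height in NCHW), so that channel count is unchanged. A small sketch of this bookkeeping; sdc_input_sizes is a hypothetical helper, and the hyperparameter values are arbitrary examples.

```python
def sdc_input_sizes(in_channels, n_kernels, n_conv):
    """Hypothetical helper: input channel count for each stacked SDC block.

    Block i outputs n_kernels[i] * n_conv channels (n_conv dilations
    concatenated along dim=1), which becomes the input of block i + 1.
    """
    sizes = [in_channels]
    for k in n_kernels[:-1]:
        sizes.append(k * n_conv)
    return sizes


n_kernels = [16, 32, 64]
n_conv = 4
input_size = sdc_input_sizes(3, n_kernels, n_conv)
print(input_size)              # [3, 64, 128]
print(n_kernels[-1] * n_conv)  # channels entering conv_head: 256
```

Passing an input_size list that disagrees with this arithmetic would make F.conv2d raise a shape error in the first mismatched block.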

Nothing looks blatantly wrong code-wise. I tested the custom conv module, and gradients are flowing and look reasonable enough.
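One way to check that the sharing behaves as intended, sketched with arbitrary sizes: because the same tensor is used in every parallel convolution, autograd sums the gradient contributions from all dilations into weight.grad. The sketch compares that accumulated gradient with the sum of gradients computed one dilation at a time.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dilations = [1, 2, 3]
x = torch.randn(2, 3, 16, 16)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)

# Shared-weight forward: one kernel, several dilations, concatenated on channels
out = torch.cat(
    [F.conv2d(x, weight, padding='same', dilation=d) for d in dilations], dim=1)
out.pow(2).mean().backward()
shared_grad = weight.grad.clone()

# Reference: gradient of the same loss, computed one dilation at a time
summed = torch.zeros_like(weight)
for d in dilations:
    y = F.conv2d(x, weight, padding='same', dilation=d)
    # Each branch holds 1/3 of the elements, so the overall mean equals
    # the sum of the per-branch means divided by the number of branches
    loss_d = y.pow(2).mean() / len(dilations)
    (g,) = torch.autograd.grad(loss_d, weight)
    summed += g

print(torch.allclose(shared_grad, summed, rtol=1e-4, atol=1e-5))
```

If the two gradients match, the shared kernel is receiving updates from every dilation, which is what weight sharing is supposed to do.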

Is there a reason why you're sharing the same weights for each dilation? This concept seems very strange to me. I could see the kernels struggling because they're asked to operate locally (dilation=1) and, at the same time, more globally (increasing dilation).

Thanks for the reply! The idea is described in section 3.2 of the paper https://arxiv.org/pdf/1904.03076.pdf:

Thus, it is reasonable to share weights between the parallel convolutions within one SDC block. The only requirement is that the parallel convolutions are of the same shape. By sharing weights, the amount of parameters gets divided by the number of parallel convolutions (factor 4 in our case). This allows to construct very light-weight feature networks with a comparatively large receptive field.
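The parameter saving described in the quote can be checked directly. A minimal sketch, assuming 4 parallel convolutions as in the quoted passage and arbitrary channel sizes; nn.Conv2d stands in for DilatedConv here because the parameter shapes are the same, and dilation does not change the parameter count.

```python
import torch.nn as nn

in_channels, out_channels, kernel_size, n_conv = 8, 16, 3, 4

# Weight sharing: one kernel (and bias) reused for every dilation
shared = nn.Conv2d(in_channels, out_channels, kernel_size)

# No sharing: a separate convolution per dilation
separate = nn.ModuleList(
    [nn.Conv2d(in_channels, out_channels, kernel_size, dilation=d)
     for d in range(1, n_conv + 1)]
)


def count(module):
    """Total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters())


print(count(shared), count(separate))     # 1168 4672
print(count(separate) // count(shared))   # 4
```

Each convolution has 16 * 8 * 3 * 3 = 1152 weights plus 16 biases, so four unshared convolutions need exactly four times as many parameters as the shared one.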

Maybe my implementation of this idea is wrong?