How to use the weights of a Conv2D in order to initialize the weights of another Conv2D?

Hi all,

I’m interested in a very simple idea, which I’ll try to explain with a toy example. So, I want to use the weights of a 2D convolution (along with its bias terms) in order to initialize the weights (and biases) of another 2D convolution. More specifically, I want to initialize the weights of the second convolution with the element-wise squares of the weights of the first convolution.

However, I want this to mean that the weights shared between the two convolutions refer to the same variables; that is, they are essentially the same weights (they don’t just have the same values at initialization).

Please take a look at the following example. I would like to tell PyTorch that conv_2 should be initialized with the squares of conv_1’s weights (and with the same biases). As a result, the number of parameters of this network would equal the number of parameters of conv_1 alone.

Is this possible?

Thank you for your time.

import torch
import torch.nn as nn
import torch.nn.init as init

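def weights_init(m):
    if isinstance(m, nn.Conv2d):
        # The initialization scheme for conv_1 is not specified here;
        # e.g. init.kaiming_uniform_(m.weight) could go in its place.
        pass
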
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv_2(self.conv_1(x))

net = Net()

# Initialization of conv_1
weights_init(net.conv_1)

x = torch.randn(1, 16, 300, 300).requires_grad_(True)
y = net(x)


If I understand correctly, you want more than just to “initialize” the weights based on the other conv: you want the weights to be a function of the other conv’s weights?
In that case, the two convs should have the same parameters (in/out channels, kernel size, etc.), because otherwise the weight and bias tensors won’t have the same size when the input and output channels differ.
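To make the size mismatch concrete with the two layers from the question (just a quick check): nn.Conv2d stores its weight as (out_channels, in_channels, kernel_height, kernel_width), so the squared weights of conv_1 cannot be plugged directly into conv_2.

```python
import torch.nn as nn

# The two convolutions from the question.
conv_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
conv_2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

# Conv2d weights are laid out as (out_channels, in_channels, kH, kW).
print(conv_1.weight.shape)  # torch.Size([32, 16, 3, 3])
print(conv_2.weight.shape)  # torch.Size([64, 32, 3, 3])
print(conv_1.bias.shape, conv_2.bias.shape)  # torch.Size([32]) torch.Size([64])

# The element-wise squares keep conv_1's shape, which does not match conv_2.
print((conv_1.weight ** 2).shape == conv_2.weight.shape)  # False
```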

If they have the same params, you can write a simple custom module:

import torch.nn as nn
from torch.nn import functional as F

class LinkedConv(nn.Module):
    def __init__(self, other_conv):
        # nn.Module must be initialized before any attribute is assigned
        super().__init__()
        # To prevent nn from thinking that we have a submodule with learnable parameters,
        # we put it in a list (which nn won't inspect when looking for nn.Modules).
        self.other_conv = [other_conv]

    def forward(self, input):
        other_conv = self.other_conv[0]
        # I just square here, but you can do anything differentiable
        my_weights = other_conv.weight ** 2
        # Same bias as other_conv, as asked in the question
        my_bias = other_conv.bias
        # I also assume you have exactly the same params as the other_conv
        return F.conv2d(input, my_weights, my_bias, other_conv.stride,
                        other_conv.padding, other_conv.dilation, other_conv.groups)

You can compare this with the original forward method of Conv2d [here] to see the difference.

Then, you can create your convolutions as:

self.conv_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
self.conv_2 = LinkedConv(self.conv_1)
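A self-contained sketch of the whole thing, under a few assumptions: the module calls super().__init__(), the bias is passed through unchanged (the question asks for the same biases), and conv_1 is changed to 16 -> 16 channels. That last change matters: a LinkedConv reuses conv_1’s configuration, so with conv_1 at 16 -> 32 it would also expect 16 input channels and could not be applied to conv_1’s 32-channel output in forward().

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class LinkedConv(nn.Module):
    def __init__(self, other_conv):
        super().__init__()  # nn.Module must be initialized before attributes are set
        # A plain list hides other_conv from nn.Module, so its parameters
        # are not registered (or counted) a second time here.
        self.other_conv = [other_conv]

    def forward(self, input):
        other_conv = self.other_conv[0]
        weight = other_conv.weight ** 2  # recomputed from conv_1 on every call
        bias = other_conv.bias           # same bias as other_conv
        return F.conv2d(input, weight, bias, other_conv.stride,
                        other_conv.padding, other_conv.dilation, other_conv.groups)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumption: 16 -> 16 channels, so the linked conv can follow conv_1.
        self.conv_1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)
        self.conv_2 = LinkedConv(self.conv_1)

    def forward(self, x):
        return self.conv_2(self.conv_1(x))

net = Net()
n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 2320 = 16*16*3*3 weights + 16 biases, all belonging to conv_1

x = torch.randn(1, 16, 32, 32)
net(x).sum().backward()
# conv_1.weight accumulates gradient from both of its uses.
print(net.conv_1.weight.grad.shape)  # torch.Size([16, 16, 3, 3])
```

Only conv_1’s parameters are counted, and backward() sends gradients into conv_1.weight from both layers, which is the “same variables” behaviour the question asks for.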

If your convs don’t have the same params and you have a specific algorithm for changing the weight size, you can add some args to the LinkedConv class and use a more complex function to create my_weights (and my_bias).
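As an illustration with the exact sizes from the question (16 -> 32 for conv_1, 32 -> 64 for conv_2), here is a hypothetical ExpandedLinkedConv. The mapping it uses (square the weights, then tile them along both channel dimensions with Tensor.repeat) is purely an assumed example; any differentiable function that produces a (64, 32, 3, 3) weight and a (64,) bias would work the same way.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class ExpandedLinkedConv(nn.Module):
    """Hypothetical variant: derives a larger kernel from other_conv's weights."""

    def __init__(self, other_conv, out_factor, in_factor):
        super().__init__()
        self.other_conv = [other_conv]  # plain list: not registered as a submodule
        self.out_factor = out_factor
        self.in_factor = in_factor

    def forward(self, input):
        other = self.other_conv[0]
        # Assumed mapping: square, then tile along the out/in channel dims.
        # (32, 16, 3, 3) -> (32 * out_factor, 16 * in_factor, 3, 3)
        weight = (other.weight ** 2).repeat(self.out_factor, self.in_factor, 1, 1)
        bias = other.bias.repeat(self.out_factor)  # (32,) -> (32 * out_factor,)
        return F.conv2d(input, weight, bias, other.stride,
                        other.padding, other.dilation, other.groups)

conv_1 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
conv_2 = ExpandedLinkedConv(conv_1, out_factor=2, in_factor=2)  # 32 -> 64 channels

x = torch.randn(1, 16, 30, 30)
y = conv_2(conv_1(x))
print(y.shape)  # torch.Size([1, 64, 30, 30])
print(sum(p.numel() for p in conv_1.parameters()))  # 4640, the only parameters
```

Because the tiled kernel is rebuilt from conv_1 in every forward pass, the network still has only conv_1’s 4640 parameters.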


Hi @albanD, many thanks for your quick response. This is pretty much exactly what I was looking for. I wouldn’t have thought of the “listing” trick.

Many thanks!
