Initializing a network with specific (pre-trained) VGG16 parameters

Hi all,

I am new to PyTorch (I have good experience with Theano/Lasagne), and I am trying to build an SSD-like architecture.

I define the following class (apologies for its length):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SSD(nn.Module):

    def __init__(self, init_weights=True):
        super(SSD, self).__init__()

        # ===================================[ G_1 ]================================== #
        self.g1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_2 ]================================== #
        self.g2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_3 ]================================== #
        self.g3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(256, 256, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_4 ]================================== #
        self.g4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_5 ]================================= #
        self.g5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_6 ]================================= #
        self.g6 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                                nn.ReLU(inplace=True))

        # ===================================[ G_7 ]================================= #
        self.g7 = nn.Sequential(nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_8 ]================================= #
        self.g8 = nn.Sequential(nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(inplace=True))

        # ===================================[ G_9 ]================================= #
        self.g9 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0),
                                nn.ReLU(inplace=True))

        # ===================================[ G_10 ]================================= #
        self.g10 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0),
                                 nn.ReLU(inplace=True),
                                 nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0),
                                 nn.ReLU(inplace=True))

        # =========================[ Face detection scaling 1 ]======================= #
        self.face_det_s1 = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid())
        self.face_bb_s1 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)

        # =========================[ Face detection scaling 2 ]======================= #
        self.face_det_s2 = nn.Sequential(nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid())
        self.face_bb_s2 = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)

        # =========================[ Face detection scaling 3 ]======================= #
        self.face_det_s3 = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid())
        self.face_bb_s3 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)

        # =========================[ Face detection scaling 4 ]======================= #
        self.face_det_s4 = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid())
        self.face_bb_s4 = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)

        # =========================[ Face detection scaling 5 ]======================= #
        self.face_det_s5 = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid())
        self.face_bb_s5 = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)

        # =========================[ Face detection scaling 6 ]======================= #
        self.face_det_s6 = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1), nn.Sigmoid())
        self.face_bb_s6 = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)

        # Initialize weights
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
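        """Run the SSD-like network on a batch of images.

        Args:
            x: input image batch of shape (N, 3, H, W).

        Returns:
            f_s1, ..., f_s6: per-scale outputs of shape (N, 5, H_i, W_i); channel 0
                is the sigmoid face score, channels 1-4 the bounding-box values.
            f: all scales flattened and concatenated, shape (N, sum_i H_i * W_i, 5).
        """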

        x = F.max_pool2d(self.g1(x), kernel_size=2, stride=2, padding=0)

        x = F.max_pool2d(self.g2(x), kernel_size=2, stride=2, padding=0)

        x = F.max_pool2d(self.g3(x), kernel_size=2, stride=2, padding=0)

        x = x_s1 = self.g4(x)

        f_s1 = torch.cat([self.face_det_s1(x_s1), self.face_bb_s1(x_s1)], dim=1)

        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=0)

        x = F.max_pool2d(self.g5(x), kernel_size=3, stride=1, padding=1)

        x = x_s2 = self.g6(x)

        f_s2 = torch.cat([self.face_det_s2(x_s2), self.face_bb_s2(x_s2)], dim=1)

        x = x_s3 = self.g7(x)

        f_s3 = torch.cat([self.face_det_s3(x_s3), self.face_bb_s3(x_s3)], dim=1)

        x = x_s4 = self.g8(x)

        f_s4 = torch.cat([self.face_det_s4(x_s4), self.face_bb_s4(x_s4)], dim=1)

        x = x_s5 = self.g9(x)

        f_s5 = torch.cat([self.face_det_s5(x_s5), self.face_bb_s5(x_s5)], dim=1)

        x_s6 = self.g10(x)

        f_s6 = torch.cat([self.face_det_s6(x_s6), self.face_bb_s6(x_s6)], dim=1)

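        # Move channels last before flattening, so that each row holds the five
        # values (score + 4 box values) of a single spatial location; a plain
        # view(-1, H * W, 5) on the (N, 5, H, W) tensor would mix locations.
        f_s1_ = f_s1.permute(0, 2, 3, 1).reshape(f_s1.size(0), -1, 5)
        f_s2_ = f_s2.permute(0, 2, 3, 1).reshape(f_s2.size(0), -1, 5)
        f_s3_ = f_s3.permute(0, 2, 3, 1).reshape(f_s3.size(0), -1, 5)
        f_s4_ = f_s4.permute(0, 2, 3, 1).reshape(f_s4.size(0), -1, 5)
        f_s5_ = f_s5.permute(0, 2, 3, 1).reshape(f_s5.size(0), -1, 5)
        f_s6_ = f_s6.permute(0, 2, 3, 1).reshape(f_s6.size(0), -1, 5)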

        f = torch.cat([f_s1_, f_s2_, f_s3_, f_s4_, f_s5_, f_s6_], dim=1)

        return f_s1, f_s2, f_s3, f_s4, f_s5, f_s6, f

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

Now, what I want is to pre-load the weights of the components self.g1 to self.g5 with the weights of the corresponding convolutional layers of VGG16 (those found in ‘https://download.pytorch.org/models/vgg16-397923af.pth’). These components match the convolutional layers of VGG16 exactly, but I don’t know how to pre-load them (or how to make them non-learnable during training).

For this, I guess I need to change the self._initialize_weights() function so that it initializes everything else, but not the weights of self.g1 to self.g5. After that (outside the class, I think), I need to load the VGG16 weights into the g1 to g5 part of my network.
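A minimal sketch of that idea, assuming the top-level attribute names g1 to g5 from the SSD class above. `initialize_weights` is a hypothetical free-function variant of `_initialize_weights` that skips the named blocks; the Kaiming settings are copied from the original method. Strictly speaking, the skip is optional: copying the pre-trained values afterwards overwrites whatever the init wrote, so it mainly documents intent.

```python
import torch.nn as nn


def initialize_weights(model: nn.Module, skip=("g1", "g2", "g3", "g4", "g5")) -> None:
    """Kaiming-initialize every Conv2d in `model`, except those inside the
    top-level submodules whose names are listed in `skip` (the VGG16-shaped
    blocks that will receive pre-trained weights afterwards)."""
    for name, child in model.named_children():
        if name in skip:
            continue  # keep the default init; pre-trained values come later
        # child.modules() includes child itself, so a bare Conv2d child
        # (e.g. face_bb_s1) is handled as well as a Sequential block.
        for m in child.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
```

Inside `SSD.__init__`, one would then call `initialize_weights(self)` instead of `self._initialize_weights()`.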

Could you please help me with that?

Many thanks!


You could create a method to load your weights.
I created a small example for you.
You could definitely write the code in a more compact way, so treat this only as a starting point.

import torch.nn as nn
from torchvision.models import vgg16

class SSD(nn.Module):
    def __init__(self, init_weights=True):
        super(SSD, self).__init__()

        # ===================================[ G_1 ]================================== #
        self.g1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(inplace=True),
                                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.g1(x)
        return x


vgg = vgg16(pretrained=False)
model = SSD()

model.g1[0].weight.data.copy_(vgg.features[0].weight.data)
model.g1[0].bias.data.copy_(vgg.features[0].bias.data)
model.g1[2].weight.data.copy_(vgg.features[2].weight.data)
model.g1[2].bias.data.copy_(vgg.features[2].bias.data)

Hi @ptrblck,

thank you very much for your quick (and very helpful) response.

Just a quick question: I assume you use the pre-trained vgg16 model from torchvision.models (import torchvision.models as models), right? If so, you load it as vgg = models.vgg16(pretrained=False), but since I need the pre-trained weights, shouldn’t I use pretrained=True? (By the way, what was it pre-trained on? ImageNet with 1000 classes, maybe?)

Yes, you are right. I just loaded it with pretrained=False to give you a quick example and to skip the downloading of the weights. You should definitely set it to True.

Yes, the torchvision classification models, VGG16 included, are pre-trained on ImageNet (1000 classes).

Brilliant! Thank you very much! A final question, if I may. How could I set these particular layers as non-trainable (exclude them from the optimization process)?

Thanks again!

You could set the requires_grad attribute of the frozen layers' parameters to False:

model.g1[2].weight.requires_grad = False
...
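To freeze all of g1 to g5 rather than one tensor at a time, a loop over each block's parameters is enough. A sketch; `freeze` is a hypothetical helper name, and the built-in `nn.Module.requires_grad_(False)` does the same thing in a single call per module.

```python
import torch.nn as nn


def freeze(*modules: nn.Module) -> None:
    """Set requires_grad=False on every parameter (weights and biases) of the
    given modules, so autograd computes no gradients for them."""
    for module in modules:
        for p in module.parameters():
            p.requires_grad = False
```

For the SSD model this would be `freeze(model.g1, model.g2, model.g3, model.g4, model.g5)`.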

Based on your code, you could also set it in the loop where you assign the pre-trained weights to the parameters.
In addition, you could pass only the trainable parameters to the optimizer:

import torch.optim as optim

optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
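A small self-contained check of that idea, using a hypothetical two-layer toy model in place of the SSD network. The first layer is frozen and only the remaining parameters reach the optimizer. After one SGD step, the frozen weights are unchanged while the trainable ones have moved. It builds a list instead of the `filter` object, because a filter iterator can be consumed only once. This matters if you also want to inspect or reuse the selection.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Hypothetical toy model: index 0 plays a frozen pre-trained block, index 1 a new head.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
model[0].requires_grad_(False)  # freeze the "pre-trained" part

# Only parameters that still require gradients are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable, lr=0.1)

frozen_before = model[0].weight.detach().clone()
head_before = model[1].weight.detach().clone()

loss = model(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

n_in_optimizer = sum(len(g["params"]) for g in optimizer.param_groups)
frozen_unchanged = torch.equal(model[0].weight.detach(), frozen_before)
head_changed = not torch.equal(model[1].weight.detach(), head_before)
```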

I see, thank you very much!

Hi ptrblck,

Why use copy_? Can’t we directly assign with =?

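If you want to manipulate the parameters in place, use copy_ and wrap it in a with torch.no_grad(): block (otherwise autograd raises an error, because the parameters are leaf tensors that require gradients). Alternatively, if you want to create new parameters, assign new nn.Parameter objects; in that case, create the optimizer afterwards, since an optimizer created earlier still holds references to the old parameters.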
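The difference can be seen on a single `nn.Linear` layer. A plain `=` with an ordinary tensor is rejected, because a registered parameter attribute only accepts an `nn.Parameter` (or `None`). Copying in place keeps the same `Parameter` object, so an existing optimizer still updates it. Assigning a new `nn.Parameter` swaps the object, so an optimizer built earlier keeps pointing at the old one and has to be re-created. `pretrained_w` is a made-up stand-in for real pre-trained values.

```python
import torch
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(3, 2)
optimizer = optim.SGD(layer.parameters(), lr=0.1)
pretrained_w = torch.ones(2, 3)  # hypothetical stand-in for pre-trained values

# Plain "=" with an ordinary tensor: rejected, the attribute is a registered Parameter.
try:
    layer.weight = pretrained_w
    plain_tensor_rejected = False
except TypeError:
    plain_tensor_rejected = True

# 1) In-place copy: same Parameter object, new values; the optimizer still tracks it.
old_param = layer.weight
with torch.no_grad():  # without this, copy_ on a leaf that requires grad raises a RuntimeError
    layer.weight.copy_(pretrained_w)
same_object = layer.weight is old_param
tracked_after_copy = any(p is layer.weight for p in optimizer.param_groups[0]["params"])

# 2) New Parameter: the old object is replaced, so this optimizer holds a stale reference.
layer.weight = nn.Parameter(pretrained_w.clone())
tracked_after_assign = any(p is layer.weight for p in optimizer.param_groups[0]["params"])

optimizer = optim.SGD(layer.parameters(), lr=0.1)  # re-create it after swapping parameters
tracked_after_rebuild = any(p is layer.weight for p in optimizer.param_groups[0]["params"])
```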

We are trying to implement SegNet, whose encoder weights come from VGG16.
So we created new CNN layers and are loading the encoder weights from the corresponding VGG16 layers. VGG16 is used only for weight initialization; the encoder weights remain trainable afterwards, and VGG16 is otherwise not part of the forward pass.