Copying the weights of one model into another model

Hello everyone, I have written my own VGG16_bn network and I want to use the pretrained weights of `torchvision.models.vgg16_bn`. How can I copy the weights of torchvision's vgg16_bn into my custom vgg16_bn? The two models have the same number of parameters.

My first attempt was:

# Attempt 1: pair parameters by their fully qualified name and copy the
# source value into the matching target parameter. Any source name without
# an identically named target parameter is skipped without a warning. When
# the two models use different module names, nothing at all is copied.
target_by_name = {name: tensor for name, tensor in model2.named_parameters()}

for source_name, source_param in model1.named_parameters():
    target_param = target_by_name.get(source_name)
    if target_param is not None:
        target_param.data.copy_(source_param.data)

but it did not work, because the two models use different keys in their state_dict, so no names match and nothing gets copied. Is there another way to deal with this problem?

Second attempt:

def copy_weights(model1, model2):
    """Copy every parameter *and* buffer of ``model1`` into ``model2`` by position.

    The two models may use completely different attribute names (for example
    ``features.0.weight`` vs ``p1.0.conv.weight``). Entries are paired purely
    by their order in ``state_dict()``, so both models must register their
    layers in the same order.

    ``state_dict()`` is used rather than ``parameters()`` because it also
    contains buffers such as BatchNorm ``running_mean`` / ``running_var`` /
    ``num_batches_tracked``. ``parameters()`` skips these, which would leave
    the target's running statistics at their initial values.

    Side effect: both models are put into eval mode.

    Args:
        model1: source ``nn.Module``; its values are only read.
        model2: destination ``nn.Module``; updated in place.

    Raises:
        RuntimeError: if the two state dicts have a different number of
            entries, or if any paired entries differ in shape. All checks run
            before anything is copied, so ``model2`` is left untouched on failure.
    """
    model1.eval()
    model2.eval()

    src_items = list(model1.state_dict().items())
    dst_items = list(model2.state_dict().items())

    # zip() would silently truncate, so insist on an exact one-to-one pairing.
    if len(src_items) != len(dst_items):
        raise RuntimeError(
            f"cannot copy weights: source has {len(src_items)} state_dict entries, "
            f"destination has {len(dst_items)}"
        )

    pairs = list(zip(src_items, dst_items))

    # Tensor.copy_ broadcasts, so a shape mismatch could otherwise pass
    # silently. Validate everything first to avoid a half-copied model.
    for (src_name, src), (dst_name, dst) in pairs:
        if src.shape != dst.shape:
            raise RuntimeError(
                f"cannot copy weights: shape mismatch between source '{src_name}' "
                f"{tuple(src.shape)} and destination '{dst_name}' {tuple(dst.shape)}"
            )

    with torch.no_grad():
        for (_, src), (_, dst) in pairs:
            # state_dict() returns detached tensors that share storage with the
            # module, so an in-place copy writes straight into model2.
            dst.copy_(src)

This one did not work either.

My custom VGG16_bn is written as follows:

def _init_weights(self):
    """ Standard weight initializer """
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)


class Conv(nn.Module):
    """Conv2d followed by an optional BatchNorm2d and an optional activation.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels (also the BatchNorm width).
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv2d``.
        norm: if True, apply ``nn.BatchNorm2d(out_channels)`` after the conv;
            otherwise use ``nn.Identity``.
        act: ``'relu'`` for an in-place ReLU, ``'sigmoid'`` for a Sigmoid, or
            ``None`` (the default) for no activation.

    Raises:
        ValueError: if ``act`` is not one of ``None``, ``'relu'`` or ``'sigmoid'``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            dilation: int = 1,
            groups: int = 1,
            norm: bool = False,
            act: str = None,
            bias: bool = True,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups,
                              bias=bias)
        self.norm = nn.BatchNorm2d(num_features=out_channels) if norm else nn.Identity()
        # Previously act=None (and any unrecognised string) silently produced a
        # Sigmoid. Map the options explicitly and reject unknown values instead.
        self.act = self._build_activation(act)

    @staticmethod
    def _build_activation(act: str = None) -> nn.Module:
        """Return the activation module selected by ``act`` (see class docstring)."""
        if act is None:
            return nn.Identity()
        if act == 'relu':
            return nn.ReLU(inplace=True)
        if act == 'sigmoid':
            return nn.Sigmoid()
        raise ValueError(
            f"unsupported activation {act!r}; expected None, 'relu' or 'sigmoid'"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> norm -> activation to ``x``."""
        return self.act(self.norm(self.conv(x)))


class VGG16(nn.Module):
    """VGG16 with batch normalization.

    The five convolutional stages ``p1``..``p5`` each end in a 2x2 / stride-2
    max-pool, so stage ``pk`` downsamples its input by ``2**k`` overall. They
    are followed by an adaptive average pool to 7x7 and a three-layer fully
    connected classifier.

    Module names and order define the ``state_dict`` keys (for example
    ``p1.0.conv.weight``, ``p1.0.norm.running_mean``, ``classifier.0.weight``).
    Keep them stable if weights are loaded from elsewhere.
    NOTE(review): the registration order appears to mirror torchvision's
    ``vgg16_bn`` (conv, then BN, per layer; Linear layers at classifier indices
    0/3/6), which is what makes position-based weight copying possible —
    confirm before relying on it.

    Args:
        num_classes: output size of the final Linear layer.
        init_weights: if True, re-initialise all layers with ``_init_weights``.
        dropout: dropout probability used in the classifier.
    """

    def __init__(self, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5) -> None:
        super().__init__()

        self.p1 = self._make_stage(3, 64, num_convs=2)     # p1/2
        self.p2 = self._make_stage(64, 128, num_convs=2)   # p2/4
        self.p3 = self._make_stage(128, 256, num_convs=3)  # p3/8
        self.p4 = self._make_stage(256, 512, num_convs=3)  # p4/16
        self.p5 = self._make_stage(512, 512, num_convs=3)  # p5/32

        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512 * 7 * 7, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

        if init_weights:
            _init_weights(self)

    @staticmethod
    def _make_stage(in_channels: int, out_channels: int, num_convs: int) -> nn.Sequential:
        """Build ``num_convs`` 3x3 Conv-BN-ReLU blocks followed by a 2x2 max-pool.

        Only the first conv changes the channel count. The Sequential indices
        (``0..num_convs-1`` for the convs, ``num_convs`` for the pool) become
        part of the state_dict keys, so this layout must not change.
        """
        convs = [
            Conv(in_channels=in_channels if i == 0 else out_channels,
                 out_channels=out_channels, kernel_size=3, padding=1,
                 norm=True, act='relu')
            for i in range(num_convs)
        ]
        return nn.Sequential(*convs, nn.MaxPool2d(kernel_size=2, stride=2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must have 3 channels (the first conv expects ``in_channels=3``).
        Presumably it is ``(batch, 3, H, W)`` with H and W large enough to
        survive five 2x poolings — TODO confirm the intended input size.
        """
        p1 = self.p1(x)
        p2 = self.p2(p1)
        p3 = self.p3(p2)
        p4 = self.p4(p3)
        p5 = self.p5(p4)
        x = self.pool(p5)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

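One proper approach is to remap the keys of the pretrained state_dict to your model's keys using a custom mapping. Because both models register their layers in the same order, the mapping can be built by pairing the two state_dict key lists position by position (checking that each pair has the same shape) and then loading the remapped dict with `load_state_dict`. Use the state_dict rather than `parameters()`: the state_dict also contains the BatchNorm buffers (`running_mean`, `running_var`, `num_batches_tracked`), which `parameters()` does not return, so copying parameters alone leaves the running statistics at their default values.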