ResNet101: modifying a conv layer to accept a concatenated tensor

Hi,

I’ve loaded a pretrained resnet101 model and have frozen the gradients for layers 1 to 3.

I have been training it successfully for my task.

I would now like to try to enhance its performance with a VQ-VAE: extract an embedding from its encoder and concatenate it into layer 4 during the forward pass.

class ResNet101(nn.Module):
    def __init__(self, num_classes, weights=ResNet101_Weights.IMAGENET1K_V2,
                 dropout_rate=0.5):
        super(ResNet101, self).__init__()

        self.resnet = models.resnet101(weights=weights)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Freeze all other parameters
        for param in self.resnet.parameters():
            param.requires_grad = False

        # Unfreeze Layer 4
        for param in self.resnet.layer4.parameters():
            param.requires_grad = True

        self.resnet.layer4[0].conv1 = nn.Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).to(self.device)  # Replace conv1 to accept the 1536-channel concatenated input and move it to the GPU

        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate),
        )

        self.skip_connection = nn.Linear(num_features, 512)

        self.fc_final = nn.Linear(1024, num_classes)

    def forward(self, x, z):
        # After layer3, x is torch.Size([16, 1024, 14, 14])
        # z is torch.Size([16, 512, 14, 14])
        # After concatenation, x is torch.Size([16, 1536, 14, 14])
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)

        # Concatenate z with the input
        x = torch.cat((x, z), dim=1).to(self.device)  # Move x to GPU

        x = self.resnet.layer4(x)

        x_avg = self.resnet.avgpool(x)
        x_avg = torch.flatten(x_avg, 1)

        x_fcl = self.resnet.fc(x_avg)
        x_skip = self.skip_connection(x_avg)
        x_fcl_skip = torch.cat((x_fcl, x_skip), dim=1)

        x_final = self.fc_final(x_fcl_skip)

        return x_final

At the moment, it is not working and fails with the following error:

RuntimeError                              Traceback (most recent call last)
Cell In[42], line 31
     28 classifier_optimizer.zero_grad()
     30 # Forward pass
---> 31 vq_vae_output, classifier_output = vq_vae_classifier(images)
     33 # Compute the VQ-VAE loss, classifier loss and total loss
     34 vq_vae_loss, classifier_loss, total_loss = criterion(vq_vae_output, classifier_output, images, statuses)

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[29], line 15, in VQVAEClassifier.forward(self, images)
     11 z = self.vq_vae.encoder(images).to(device)
     13 images_normalized = normalize_batch(images, IMG_MEAN, IMG_STD)
---> 15 predicted_statuses = self.classifier(images_normalized, z)
     17 return ( (images_reconstructed, commitment_loss, codebook_loss, perplexity), predicted_statuses )

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\final\deep_learning\resnet101\resnet101.py:64, in ResNet101.forward(self, x, z)
     61 # Concatenate z with the input
     62 x = torch.cat((x, z), dim=1)
---> 64 x = self.resnet.layer4(x)
     66 x_avg = self.resnet.avgpool(x)
     67 x_avg = torch.flatten(x_avg, 1)

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torchvision\models\resnet.py:158, in Bottleneck.forward(self, x)
    155 out = self.bn3(out)
    157 if self.downsample is not None:
--> 158     identity = self.downsample(x)
    160 out += identity
    161 out = self.relu(out)

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\conv.py:463, in Conv2d.forward(self, input)
    462 def forward(self, input: Tensor) -> Tensor:
--> 463     return self._conv_forward(input, self.weight, self.bias)

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\conv.py:459, in Conv2d._conv_forward(self, input, weight, bias)
    455 if self.padding_mode != 'zeros':
    456     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    457                     weight, bias, self.stride,
    458                     _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [2048, 1024, 1, 1], expected input[16, 1536, 14, 14] to have 1024 channels, but got 1536 channels instead

I’m not sure how to resolve this. Is it even possible to do what I’m attempting?

The error is raised in self.resnet.layer4 since you are concatenating x with z, increasing the channel dimension from the original 1024 to 1536.
You would thus need to replace the conv layer in self.resnet.layer4 with a newly initialized one accepting 1536 channels and train it.
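
For example, a rough sketch of that replacement could look like the code below. This is an untested assumption based on the traceback: the failing weight of size [2048, 1024, 1, 1] belongs to the 1x1 conv inside layer4[0].downsample, which also receives the block input, so it would need to be widened alongside conv1 if the whole layer4 is called on the concatenated tensor.

import torch.nn as nn
from torchvision.models import resnet101, ResNet101_Weights

resnet = resnet101(weights=ResNet101_Weights.IMAGENET1K_V2)
block = resnet.layer4[0]

# Widen the first 1x1 conv of the bottleneck to take the 1536-channel concatenated input
block.conv1 = nn.Conv2d(1536, 512, kernel_size=1, stride=1, bias=False)

# The downsample branch also consumes the block input (it is the [2048, 1024, 1, 1]
# weight in the traceback), so its 1x1 conv has to be widened as well; both layers are
# newly initialized and therefore need to be trained
block.downsample[0] = nn.Conv2d(1536, 2048, kernel_size=1, stride=2, bias=False)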

Hi @ptrblck, thank you for your reply.

I assumed I was already doing that with the line above?

Hi @ptrblck, I managed to get it working:


import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet101_Weights

class ResNet101(nn.Module):
    def __init__(self, num_classes, weights=ResNet101_Weights.IMAGENET1K_V2,
                 dropout_rate=0.5):
        super(ResNet101, self).__init__()
        self.resnet = models.resnet101(weights=weights)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Freeze all other parameters
        for param in self.resnet.parameters():
            param.requires_grad = False

        # Unfreeze Layer 4
        for param in self.resnet.layer4.parameters():
            param.requires_grad = True

        # Replace the first conv in layer4[0] to accept the 2048-channel concatenated input (1024 from layer3 + 1024 from z)
        self.resnet.layer4[0].conv1 = nn.Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)

        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate),
        )

        self.skip_connection = nn.Linear(num_features, 512)

        self.fc_final = nn.Linear(1024, num_classes)

    def forward(self, x, z=None):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)

        if z is not None:
            # Concatenate x and z along the channel dimension
            x_concat = torch.cat((x, z), dim=1).to(self.device)  # Move x to GPU
        else:
            x_concat = x

        x_concat = self.resnet.layer4[0].conv1(x_concat)
        x_concat = self.resnet.layer4[0].bn1(x_concat)

        x_concat = self.resnet.layer4[0].conv2(x_concat)
        x_concat = self.resnet.layer4[0].bn2(x_concat)

        x_concat = self.resnet.layer4[0].conv3(x_concat)
        x_concat = self.resnet.layer4[0].bn3(x_concat)

        x_concat = self.resnet.layer4[0].relu(x_concat)

        x_concat = self.resnet.layer4[0].downsample(x)  # Apply downsample to the output of layer 3

        x_concat = self.resnet.layer4[1](x_concat)
        x_concat = self.resnet.layer4[2](x_concat)

        x_avg = self.resnet.avgpool(x_concat)
        x_avg = torch.flatten(x_avg, 1)

        x_fcl = self.resnet.fc(x_avg)
        x_skip = self.skip_connection(x_avg)
        x_fcl_skip = torch.cat((x_fcl, x_skip), dim=1)

        x_final = self.fc_final(x_fcl_skip)

        return x_final

However, I am a novice and need to be sure that this code is correct, i.e. that it still follows the original resnet101 forward pass with only the modification I intended.

I edited it slightly: both x and z are now [16, 1024, 14, 14].
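
For reference, torchvision's Bottleneck.forward (the method the traceback goes through) is structured roughly like this (paraphrased from torchvision/models/resnet.py, so treat it as a sketch rather than the exact source):

def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)
    out = self.relu(out)

    out = self.conv3(out)
    out = self.bn3(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out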

Based on the Bottleneck.forward implementation, I would argue the correct forward statement is:

    def forward(self, x, z=None):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)

        if z is not None:
            # Concatenate x and z along the channel dimension
            x_concat = torch.cat((x, z), dim=1).to(self.device)  # Move x to GPU
        else:
            x_concat = x

        x_concat = self.resnet.layer4[0].conv1(x_concat)
        x_concat = self.resnet.layer4[0].bn1(x_concat)
        x_concat = self.resnet.layer4[0].relu(x_concat)

        x_concat = self.resnet.layer4[0].conv2(x_concat)
        x_concat = self.resnet.layer4[0].bn2(x_concat)
        x_concat = self.resnet.layer4[0].relu(x_concat)

        x_concat = self.resnet.layer4[0].conv3(x_concat)
        x_concat = self.resnet.layer4[0].bn3(x_concat)

        x = self.resnet.layer4[0].downsample(x)  # Apply downsample to the output of layer 3
        x_concat += x
        x_concat = self.resnet.layer4[0].relu(x_concat)

        x_concat = self.resnet.layer4[1](x_concat)
        x_concat = self.resnet.layer4[2](x_concat)

        x_avg = self.resnet.avgpool(x_concat)
        x_avg = torch.flatten(x_avg, 1)

        x_fcl = self.resnet.fc(x_avg)
        x_skip = self.skip_connection(x_avg)
        x_fcl_skip = torch.cat((x_fcl, x_skip), dim=1)

        x_final = self.fc_final(x_fcl_skip)

        return x_final
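
A quick way to sanity-check the shapes end to end is shown below (a hypothetical example: 10 classes, standard 224x224 inputs so that layer3 produces 14x14 feature maps, and a random stand-in for the VQ-VAE embedding):

import torch

model = ResNet101(num_classes=10)
model = model.to(model.device)   # self.device is set in __init__
model.eval()                     # keep BatchNorm running stats fixed for the check

x = torch.randn(16, 3, 224, 224, device=model.device)   # layer3 output: [16, 1024, 14, 14]
z = torch.randn(16, 1024, 14, 14, device=model.device)  # stand-in for the VQ-VAE embedding

with torch.no_grad():
    out = model(x, z)

print(out.shape)  # expected: torch.Size([16, 10])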