Hi,
I've loaded a pretrained ResNet101 and frozen the gradients for layers 1 to 3, and I have been training it successfully on my task.
I would now like to try to enhance its performance with a VQ-VAE: extract an embedding from its encoder and concatenate it with the layer 3 feature map before it enters layer 4 during the forward pass.
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet101_Weights

class ResNet101(nn.Module):
    def __init__(self, num_classes, weights=ResNet101_Weights.IMAGENET1K_V2,
                 dropout_rate=0.5):
        super(ResNet101, self).__init__()
        self.resnet = models.resnet101(weights=weights)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Freeze all parameters...
        for param in self.resnet.parameters():
            param.requires_grad = False
        # ...then unfreeze layer 4
        for param in self.resnet.layer4.parameters():
            param.requires_grad = True

        # Widen the first conv of layer4 to accept the concatenated
        # 1024 + 512 = 1536 input channels, and move it to the GPU
        self.resnet.layer4[0].conv1 = nn.Conv2d(
            1536, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).to(self.device)

        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(dropout_rate),
        )
        self.skip_connection = nn.Linear(num_features, 512)
        self.fc_final = nn.Linear(1024, num_classes)

    def forward(self, x, z):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)

        # After layer3: x is torch.Size([16, 1024, 14, 14]) and
        # z is torch.Size([16, 512, 14, 14]); concatenating along the
        # channel dim gives torch.Size([16, 1536, 14, 14])
        x = torch.cat((x, z), dim=1).to(self.device)

        x = self.resnet.layer4(x)
        x_avg = self.resnet.avgpool(x)
        x_avg = torch.flatten(x_avg, 1)
        x_fcl = self.resnet.fc(x_avg)
        x_skip = self.skip_connection(x_avg)
        x_fcl_skip = torch.cat((x_fcl, x_skip), dim=1)
        x_final = self.fc_final(x_fcl_skip)
        return x_final
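For reference, here is a minimal sketch of how I invoke it (the batch size, the 224x224 input size, and num_classes=2 are placeholders; in the real code z comes from my VQ-VAE encoder):

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet101(num_classes=2).to(device)

    images = torch.randn(16, 3, 224, 224, device=device)  # layer3 output: [16, 1024, 14, 14]
    z = torch.randn(16, 512, 14, 14, device=device)       # stand-in for the VQ-VAE embedding

    out = model(images, z)  # raises the RuntimeError shown below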
At the moment it is not working: the forward pass fails with the traceback below (trimmed of the repeated Module._call_impl and Sequential.forward frames):
RuntimeError                              Traceback (most recent call last)
Cell In[42], line 31
     28 classifier_optimizer.zero_grad()
     30 # Forward pass
---> 31 vq_vae_output, classifier_output = vq_vae_classifier(images)
     33 # Compute the VQ-VAE loss, classifier loss and total loss
     34 vq_vae_loss, classifier_loss, total_loss = criterion(vq_vae_output, classifier_output, images, statuses)

Cell In[29], line 15, in VQVAEClassifier.forward(self, images)
     11 z = self.vq_vae.encoder(images).to(device)
     13 images_normalized = normalize_batch(images, IMG_MEAN, IMG_STD)
---> 15 predicted_statuses = self.classifier(images_normalized, z)
     17 return ( (images_reconstructed, commitment_loss, codebook_loss, perplexity), predicted_statuses )

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\final\deep_learning\resnet101\resnet101.py:64, in ResNet101.forward(self, x, z)
     61 # Concatenate z with the input
     62 x = torch.cat((x, z), dim=1)
---> 64 x = self.resnet.layer4(x)
     66 x_avg = self.resnet.avgpool(x)
     67 x_avg = torch.flatten(x_avg, 1)

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torchvision\models\resnet.py:158, in Bottleneck.forward(self, x)
    155 out = self.bn3(out)
    157 if self.downsample is not None:
--> 158     identity = self.downsample(x)
    160 out += identity
    161 out = self.relu(out)

File c:\Users\Tanvi\Desktop\comp5200m-msc-project\venv\Lib\site-packages\torch\nn\modules\conv.py:459, in Conv2d._conv_forward(self, input, weight, bias)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [2048, 1024, 1, 1], expected input[16, 1536, 14, 14] to have 1024 channels, but got 1536 channels instead
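Reading the traceback, my guess is that the failure comes from the downsample branch of layer4[0]: I widened conv1 to 1536 input channels, but the 1x1 downsample conv still expects the original 1024. I imagine the fix would look something like this (untested, and I'm unsure whether it would invalidate the pretrained downsample weights), added in __init__ next to the conv1 replacement:

    # Guess at a fix (untested): also widen the 1x1 conv in the downsample
    # branch of layer4[0], which currently maps 1024 -> 2048 channels with stride 2
    self.resnet.layer4[0].downsample[0] = nn.Conv2d(
        1536, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False
    ).to(self.device)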
I'm not sure whether that is the right way to resolve this. Is what I'm attempting even possible?