Added extra layers to ResNet18 model, getting unexpected results when using forward hook

I want to use contrastive loss for image classification. Until now I was training a ResNet18 (torchvision.models.resnet18()) and it was doing OK. Now I want to fine-tune it with a contrastive triplet loss, so I created a network class like so:

import torch.nn as nn

class Triplet_Net(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.Feature_Extractor = model
        num_filters = self.Feature_Extractor.fc.in_features
        num_classes = 7
        # replace the final layer with the output I'm gonna need
        self.Feature_Extractor.fc = nn.Linear(num_filters, num_classes)
        self.Triplet_Loss = nn.Sequential(nn.Linear(num_classes, 2))

    def forward(self, x):
        features = self.Feature_Extractor(x)
        triplets = self.Triplet_Loss(features)
        return triplets
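
A quick sanity check of the shapes (a minimal sketch, assuming 224×224 RGB inputs):

import torch
import torchvision

trip = Triplet_Net(torchvision.models.resnet18())
out = trip(torch.rand(1, 3, 224, 224))  # batch of one random image
print(out.shape)  # torch.Size([1, 2])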

Then I loaded the weights of a previously trained PyTorch ResNet18 model I had saved and created a Triplet_Net with the same weights (I have confirmed that loaded_model and triplet_model have the same weights after these lines of code):

loaded_model = torchvision.models.resnet18()
loaded_model.load_state_dict(torch.load(model_path))
triplet_model = Triplet_Net(loaded_model)
triplet_model

Output:
Triplet_Net(
  (Feature_Extractor): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer2): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer3): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=512, out_features=7, bias=True)
  )
  (Triplet_Loss): Sequential(
    (0): Linear(in_features=7, out_features=2, bias=True)
  )
)
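
For the weight check, I compared the two state dicts with something along these lines — a minimal sketch:

# every tensor in loaded_model's state_dict should be identical to its
# counterpart inside the wrapped Feature_Extractor
for name, tensor in loaded_model.state_dict().items():
    assert torch.equal(tensor, triplet_model.Feature_Extractor.state_dict()[name]), name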

The problem is, when I use a forward hook to get the activations of the last fc layer of the ResNet part of Triplet_Net, I get different values than when I use the exact same hook on the original ResNet. I have no idea why this is happening. No additional training was done on triplet_model, so it should behave the same as loaded_model.

Below is part of the code I use to get the fc layer output, either with the plain ResNet (classic == True) or with the custom network (classic == False):

# register the hook on whichever module holds the fc layer
if classic:
    hook = model.fc.register_forward_hook(get_activation('fc'))
else:
    hook = model.Feature_Extractor.fc.register_forward_hook(get_activation('fc'))

with torch.inference_mode():
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)
        fc_output = activation['fc']

hook.remove()

How large is the discrepancy? If it’s not large, it could be rounding errors. Some code we can use to reproduce your error would be valuable.

I don’t fully understand your use case, since you are explicitly replacing self.Feature_Extractor.fc with a new nn.Linear layer in Triplet_Net.__init__.
Are you loading the state_dict of the reference model again afterwards?
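
If not, re-loading the checkpoint after wrapping would overwrite the freshly initialized layer again, e.g. (assuming the saved fc shape matches the new one):

triplet_model = Triplet_Net(loaded_model)
# overwrite the randomly initialized fc with the trained weights again
triplet_model.Feature_Extractor.load_state_dict(torch.load(model_path))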

Thanks for replying. No, it is significant. Here is some code to reproduce it. It happens even with the weights of an untrained network, which makes me think I am doing something wrong with the forward hook.

import torch
import torch.nn as nn
import torchvision

# forward-hook helper: stores the hooked module's output under `name`
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

class Triplet_Net(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.Feature_Extractor = model
        num_filters = self.Feature_Extractor.fc.in_features
        num_classes = 7
        self.Feature_Extractor.fc = nn.Linear(num_filters, num_classes)
        self.Triplet_Loss = nn.Sequential(nn.Linear(num_classes, 2))

    def forward(self, x):
        features = self.Feature_Extractor(x)
        triplets = self.Triplet_Loss(features)
        return triplets

test_img = torch.rand(1, 3, 224, 224)  # batch of one random image
num_classes = 7

# My trained model
loaded_model = torchvision.models.resnet18()
loaded_model.fc = nn.Linear(512, num_classes)
hook = loaded_model.fc.register_forward_hook(get_activation('fc'))
loaded_model.eval()
y_pred = loaded_model(test_img)
print(activation['fc'])
hook.remove()

# The triplet model
trip_model = Triplet_Net(loaded_model)
trip_model.eval()
hook = trip_model.Feature_Extractor.fc.register_forward_hook(get_activation('fc'))
y_pred = trip_model(test_img)
print(activation['fc'])
hook.remove()

Yeah, that was it, thanks a lot!

I was basically reinitializing the fc layer weights: the constructor replaces the already-trained fc with a freshly initialized nn.Linear, so the wrapped model no longer matches the loaded one. The problem was solved when this line

self.Feature_Extractor.fc = nn.Linear(num_filters, num_classes)

was removed.
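
For anyone hitting the same thing, the working version of the class simply keeps the loaded fc (a minimal sketch; num_classes is read off the existing layer):

class Triplet_Net(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        # keep the already-trained fc from the loaded model instead of
        # replacing it with a freshly initialized nn.Linear
        self.Feature_Extractor = model
        num_classes = self.Feature_Extractor.fc.out_features  # 7 here
        self.Triplet_Loss = nn.Sequential(nn.Linear(num_classes, 2))

    def forward(self, x):
        features = self.Feature_Extractor(x)
        triplets = self.Triplet_Loss(features)
        return triplets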
