Trying to implement Transfer Learning but getting unreasonable results

Hi!
I followed the PyTorch tutorial on transfer learning, but I am getting results that are clearly wrong. Since my application is different, I had to make some tweaks, and it would be great if someone could give me some insight.

As you can see below, I am removing the final layer of the pre-trained model instead of replacing it:

import torch
from torchvision import models

# Drop the final fc layer, keeping everything up to the adaptive average pool
model = models.resnet50(pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
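
As a sanity check (the dummy input here is just to confirm the shape), the truncated model should now output 2048-dimensional feature maps, since only the fc layer was removed:

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))  # dummy batch of one image
print(features.shape)  # torch.Size([1, 2048, 1, 1]) for resnet50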

Below you can see why I took a different approach. I need to compare the features of 3 images produced by the pretrained model, and I do this by concatenating those features and passing them through my fully connected layers, which I will optimize.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_dim, 600)
        self.fc2 = nn.Linear(600, 600)
        self.fc3 = nn.Linear(600, output_dim)

    def forward(self, input_1, input_2, input_3, pretrained_model):
        # Extract features with the frozen pretrained backbone (no gradients needed)
        with torch.no_grad():
            output_1 = pretrained_model(input_1.float())
            output_2 = pretrained_model(input_2.float())
            output_3 = pretrained_model(input_3.float())

        # Concatenate the three feature maps and flatten to (batch, 3 * 2048) for resnet50
        s = torch.cat((output_1, output_2, output_3), dim=1).view(output_1.shape[0], -1)

        s = F.relu(self.fc1(s.float()))
        s = F.relu(self.fc2(s))
        output = self.fc3(s)
        return output
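
Since resnet50 yields a 2048-dimensional feature vector per image after the average pool, the network is constructed along these lines (the numbers just follow from the shapes):

net = Net(input_dim=3 * 2048, output_dim=2)  # 6144 concatenated features -> 2 classes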

The “classification” is based on whether image 2 or image 3 is more similar to image 1, so the target is either [0, 1] or [1, 0].

Using cross-entropy loss and the Adam optimizer, the running loss does decrease. However, when I pass in the validation set, the model predicts the same class for every input (since the data is evenly split between [0, 1] and [1, 0], I get an accuracy of 0.5).
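
For context, a simplified sketch of the training step (the loader and tensor names are illustrative; nn.CrossEntropyLoss takes class indices as targets, so the one-hot [0, 1] / [1, 0] labels are converted via argmax):

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

for img1, img2, img3, target in train_loader:       # target is one-hot, e.g. [0, 1]
    optimizer.zero_grad()
    logits = net(img1, img2, img3, model)           # model is the frozen resnet
    loss = criterion(logits, target.argmax(dim=1))  # convert one-hot to indices 0 / 1
    loss.backward()
    optimizer.step()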

Please let me know if you have any idea what might be going on here. I have been debugging for a long time now.

Since you don’t want to train the resnet, did you also call model.eval() before using it (or just for the validation)?

Did you process the validation images in a different way than the training images?

Hello again!
Both the training and the validation images are processed using the same functions.

Your first point, however, might be the issue. I did not use model.eval(), but I thought wrapping my forward passes in “with torch.no_grad():” would do the same job. Could it be that this is not the case?

So to be clear, instead of this in my forward pass:

with torch.no_grad():
    output_1 = pretrained_model(input_1.float())
    output_2 = pretrained_model(input_2.float())
    output_3 = pretrained_model(input_3.float())

you suggest I do this when loading the pretrained model?

model = models.resnet50(pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()  # switch the frozen backbone to evaluation mode once, after loading
model = model.to(device)

torch.no_grad() will avoid storing the intermediate tensors during validation and testing and will thus save memory.
model.eval() will change the behavior of some layers. E.g. it will disable dropout layers and use the running stats in batchnorm layers instead of the batch statistics.

If you didn’t use model.eval() during training, the batchnorm statistics would have been updated (which might be a valid use case). However, train mode shouldn’t be used during evaluation or testing, especially if your batch size is small, as a few unrepresentative batches could skew the running stats.
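
Here is a small standalone example showing that the running stats are updated in train mode even inside torch.no_grad(), while eval mode leaves them alone (shapes and values are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 5. + 10.  # batch stats far from the initial running stats

print(bn.running_mean)  # tensor([0., 0., 0.])

with torch.no_grad():   # train mode: no_grad does NOT stop the stats update
    bn(x)
print(bn.running_mean)  # has moved toward the batch mean (~10)

bn.eval()               # eval mode: running stats are used and frozen
with torch.no_grad():
    bn(x)
print(bn.running_mean)  # unchanged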

As an update, this solved the problem! Thank you!