Use pre-trained ResNet-101 for regression data

Hello everyone!
I want to use ResNet-101 for a regression-like problem. The input to the network is an image with an associated target (a number), and I want to train the model to predict that number, as in regression.

What I am doing is replacing the final fully connected layer of ResNet-101 with a linear layer, so the output is a single value:

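```python
import torch.nn as nn
from torchvision import models

model = models.resnet101(pretrained=True)  # ImageNet-pretrained backbone
num_ftrs = model.fc.in_features            # size of the features fed to the final layer
model.fc = nn.Linear(num_ftrs, 1)          # replace the 1000-class head with a single output
```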

Does this make sense? How would you suggest using ResNet or another pre-trained CNN for continuous targets?

I have the same question about how to apply a pre-trained model to my own work.

My main question is basically how to use the pre-trained ResNet for continuous targets instead of categorical ones. For categorical data you do what I wrote above, except the linear layer is nn.Linear(num_ftrs, number_of_classes). Then, when making predictions in the test phase, you can apply a softmax to get the probability of each class.
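To make the contrast concrete, here is a sketch in which a random tensor stands in for the 2048-dimensional pooled features that ResNet-101 feeds to its last layer; the feature tensor and `num_classes = 5` are illustrative assumptions. The classification head produces one logit per class and softmax turns them into probabilities, while the regression head produces a single number that is used directly as the prediction, with no softmax.

```python
import torch
import torch.nn as nn

num_ftrs = 2048   # ResNet-101 feature size before the fc layer
num_classes = 5   # hypothetical number of classes
features = torch.randn(4, num_ftrs)  # stand-in for backbone features of 4 images

# Classification: one logit per class; softmax gives class probabilities
cls_head = nn.Linear(num_ftrs, num_classes)
probs = torch.softmax(cls_head(features), dim=1)  # shape (4, 5), each row sums to 1

# Regression: one output per image, used as the predicted value directly
reg_head = nn.Linear(num_ftrs, 1)
preds = reg_head(features).squeeze(1)  # shape (4,), one real number per image
```

During training, a classifier would typically feed the raw logits to nn.CrossEntropyLoss (which applies log-softmax internally), whereas the regressor's output goes straight into a regression loss.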

This solution looks correct.

```python
model = models.resnet101(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)
```

Make sure you use a loss function suited to regression, such as torch.nn.MSELoss.

You can use a network pre-trained on a classification problem and transfer it to a regression problem. You might need to unfreeze the last blocks to adapt it to your application.

The first blocks of ResNet (or any other CNN) detect low-level features such as edges and shapes, which stay mostly the same when you change datasets.

Hope this helps!

Thank you!! How can I unfreeze the last blocks?

Something like:

```python
for param in model.parameters():
    param.requires_grad = False  # False when you freeze the layer / True when you want to train it
```

Thank you! One more thing: why do I need to unfreeze the last blocks? I think I am missing something here; I don't really understand the theory behind why it is done this way.

This tutorial gives both the explanations and the example with ResNet:

Thank you so much for the help!! :blush: