Hi, I am new to the concept of transfer learning. I have a pretrained model saved as a .pth file. The model was trained on a large number of medical images for segmentation, but I want to use it for prediction: my goal is to predict survival times for different patients from their MRI images. I thought a model that can segment these images should also extract features that are useful for predicting survival time. So I want to load pretrainedmodel.pth, delete some of its layers, and add some fully connected layers to it. First, I don't know how to get the layers of the pretrained model. Second, I don't know how to delete some of its layers and add new ones. I am already familiar with freezing layers.
If you know the answer, could you please explain it with code?
Assuming the stored file contains the trained state_dict of the model, you could create a new object of the model, load the state_dict, and use parts of your pretrained model in a new classification model.
Something like this should work:
```python
import torch

# create object of pretrained model
# (SegmentationModel and args stand for the class and constructor arguments
#  that were used to create the checkpoint)
pretrained_model = SegmentationModel(args)
# load state_dict
pretrained_model.load_state_dict(torch.load(path_to_state_dict, map_location="cpu"))
# create object of new classification model
model = ClassificationModel(pretrained_model)
# in the class definition of ClassificationModel, pretrained_model's layers will be reused
```
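A minimal sketch of what ClassificationModel could look like, assuming the segmentation network exposes its feature extractor as an attribute named `encoder` (a hypothetical name; a real model may call it something else or not separate it at all). The toy SegmentationModel below only stands in for the real architecture so the example runs. The head sizes, `num_outputs=1` (one survival-time value per patient), and freezing the encoder are illustrative choices, not requirements.

```python
import torch
import torch.nn as nn


class SegmentationModel(nn.Module):
    """Toy stand-in for the real segmentation network (hypothetical)."""

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        # encoder: extracts feature maps from the image
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # decoder: maps features back to a per-pixel segmentation
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, num_classes, kernel_size=1),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


class ClassificationModel(nn.Module):
    """Reuses the pretrained encoder and adds a new fully connected head."""

    def __init__(self, pretrained_model, num_outputs=1, freeze_encoder=True):
        super().__init__()
        self.encoder = pretrained_model.encoder  # reuse the pretrained layers
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # (N, 32, H, W) -> (N, 32, 1, 1)
            nn.Flatten(),             # -> (N, 32)
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, num_outputs),
        )

    def forward(self, x):
        return self.head(self.encoder(x))


pretrained_model = SegmentationModel()
# in practice: pretrained_model.load_state_dict(torch.load(path_to_state_dict))
model = ClassificationModel(pretrained_model)
out = model(torch.randn(4, 1, 64, 64))
print(out.shape)  # torch.Size([4, 1])
```

Only the new head's parameters require gradients here, so an optimizer built from `filter(lambda p: p.requires_grad, model.parameters())` would train just the added layers.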
I am a little confused, so excuse me if I ask basic questions. What is SegmentationModel, and what is the args passed to it?
I could only find a file on GitHub called segmentation.pth, so I have no idea what the model architecture is. First, I want to load this model. Second, I want to understand its architecture: for example, with something like model.summary() I could see that it has 1 conv, 1 ReLU, 1 pool, 1 conv, 1 ReLU, 1 global avg pool, 1 flatten, 1 FC. Then I want to delete the last layers (1 conv, 1 ReLU, 1 global avg pool, 1 flatten, 1 FC) and add some arbitrary layers instead. So these are my problems:
- How can I understand the architecture of the model? (code)
- How can I delete the last layers? (code)
- How can I add new layers? (code)
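A sketch of the three steps, assuming the .pth file contains a full pickled model (an `nn.Module`) rather than only a state_dict. `print(model)` plays roughly the role of Keras' `model.summary()`, and `named_children()` lists the top-level layers. Cutting off the last layers by slicing `children()` into an `nn.Sequential` only works when the model's `forward` is a plain chain of its children in registration order; for models with skip connections (e.g. U-Net) you would slice inside the relevant submodule instead. The toy network mimics the layer list from the question and is saved to a temporary file only so the example is self-contained; the new layer sizes are arbitrary.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy model with the layer list from the question (stand-in for the real file)
original = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 2),
)
path = os.path.join(tempfile.mkdtemp(), "segmentation.pth")
torch.save(original, path)  # saves the whole module, not just the state_dict

# 1) Load and inspect the architecture.
# weights_only=False is needed to unpickle a full module in recent PyTorch
# versions; only do this for files you trust.
model = torch.load(path, map_location="cpu", weights_only=False)
print(model)
for name, layer in model.named_children():
    print(name, "->", layer)

# 2) Delete the last 5 layers (conv, ReLU, global avg pool, flatten, FC)
backbone = nn.Sequential(*list(model.children())[:-5])

# 3) Add new layers on top of the kept (pretrained) layers
new_model = nn.Sequential(
    backbone,
    nn.Conv2d(8, 32, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(32, 1),  # e.g. one survival-time output (illustrative)
)
out = new_model(torch.randn(2, 1, 64, 64))
print(out.shape)  # torch.Size([2, 1])
```

The kept layers are the same module objects that were loaded, so they keep their pretrained weights.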
In your answer you created two objects of two classes, SegmentationModel and ClassificationModel. But what are these classes?
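SegmentationModel and ClassificationModel are not library classes; they stand for your own `nn.Module` subclasses. If the .pth file holds only a state_dict, it stores tensors keyed by attribute path (e.g. `encoder.0.weight`) and no architecture, so you typically need the model's class definition (often in the same GitHub repository as the checkpoint) or a re-created class whose layer names and shapes match before `load_state_dict` will accept the weights. A minimal sketch, where `Net` and its `hidden` argument are hypothetical and play the role of `SegmentationModel(args)`:

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical architecture that produced the checkpoint."""

    def __init__(self, hidden=16):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(1, hidden, 3, padding=1), nn.ReLU())
        self.out = nn.Conv2d(hidden, 2, 1)

    def forward(self, x):
        return self.out(self.encoder(x))


path = os.path.join(tempfile.mkdtemp(), "segmentation.pth")
torch.save(Net().state_dict(), path)  # only the weights are stored

state_dict = torch.load(path, map_location="cpu")
# Key names and tensor shapes are the only hints about the architecture
for key, tensor in state_dict.items():
    print(key, tuple(tensor.shape))
# encoder.0.weight (16, 1, 3, 3)
# encoder.0.bias (16,)
# out.weight (2, 16, 1, 1)
# out.bias (2,)

# A class with matching names and shapes loads cleanly ...
model = Net(hidden=16)
result = model.load_state_dict(state_dict)
print(result.missing_keys, result.unexpected_keys)  # [] []

# ... while a mismatched one raises a RuntimeError listing the differences
mismatch_error = None
try:
    Net(hidden=32).load_state_dict(state_dict)
except RuntimeError as err:
    mismatch_error = err
    print("mismatch:", str(err).splitlines()[0])
```

Here `args` corresponds to the constructor arguments (`hidden=16`) that must match the ones used when the checkpoint was saved.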