Just to be clear, is this what you're saying?
- Save the new fine-tuned model after training
# Train and evaluate
finetuned_model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

# Save the fine-tuned model's state_dict
torch.save(finetuned_model_ft.state_dict(), "path/to/myModel")
- Load the new fine-tuned model
model_name = "resnet"
num_classes = 2
feature_extract = True
use_pretrained = True
my_fine_tuned_model, input_size = initialize_finetuned_model(model_name, num_classes, feature_extract, use_pretrained=True)
my_fine_tuned_model.load_state_dict(torch.load("path/to/myModel"))
my_fine_tuned_model.eval()
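And then for inference I'd do roughly this (just a sketch; image_tensor is a placeholder for a preprocessed input batch):
# Rough sketch of running inference with the loaded model
# (image_tensor is a placeholder for a preprocessed input batch)
with torch.no_grad():
    output = my_fine_tuned_model(image_tensor)["out"]  # torchvision segmentation models return a dict with "out"
    preds = torch.argmax(output, dim=1)                # per-pixel class predictions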
I'm guessing, then, that the initialize_finetuned_model
function used to recreate the fine-tuned model for inference should be a little different from the initialize_model
function used for training, because I need to load the fine-tuned version instead of the original ResNet model, right?
What I mean is that initialize_model has the line model_ft = models.segmentation.fcn_resnet101(pretrained=use_pretrained),
which creates the original model, but when I load the fine-tuned version, instead of that line I would need my_fine_tuned_model = torch.load("path/to/myModel"),
correct? So the original initialize_model and the new initialize_finetuned_model
functions would look like the code below, correct?
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables, which will be set in the if statement below.
    # Each of these variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ FCN_resnet101 """
        model_ft = models.segmentation.fcn_resnet101(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        in_chnls = model_ft.classifier[4].in_channels
        model_ft.classifier[4] = nn.Conv2d(in_chnls, num_classes, 1, 1)
        input_size = 768, 1024  # 224
    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
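For reference, this is how I call it for training, with the same settings as in the load code above:
# How I currently create the model for training (same settings as above)
model_ft, input_size = initialize_model("resnet", num_classes=2,
                                         feature_extract=True, use_pretrained=True)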
And the initialize_finetuned_model function, used to recreate the fine-tuned model for inference, would be:
def initialize_finetuned_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables, which will be set in the if statement below.
    # Each of these variables is model specific.
    my_fine_tuned_model = None
    input_size = 0

    if model_name == "resnet":
        """ FCN_resnet101 """
        my_fine_tuned_model = torch.load("path/to/myModel")
        set_parameter_requires_grad(my_fine_tuned_model, feature_extract)
        in_chnls = my_fine_tuned_model.classifier[4].in_channels
        my_fine_tuned_model.classifier[4] = nn.Conv2d(in_chnls, num_classes, 1, 1)
        input_size = 768, 1024  # 224
    else:
        print("Invalid model name, exiting...")
        exit()

    return my_fine_tuned_model, input_size
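Or, since torch.save above stores a state_dict rather than the whole model, would it instead make sense for the inference function to reuse initialize_model and then call load_state_dict, roughly like this? (initialize_finetuned_model_v2 is just a hypothetical name for this sketch, and the checkpoint path is the same placeholder as above.)
# Alternative sketch: rebuild the architecture with initialize_model, then
# load the saved state_dict into it (path is the same placeholder as above).
def initialize_finetuned_model_v2(model_name, num_classes, feature_extract):
    # use_pretrained=False here since load_state_dict overwrites the weights anyway
    model, input_size = initialize_model(model_name, num_classes, feature_extract,
                                         use_pretrained=False)
    model.load_state_dict(torch.load("path/to/myModel"))
    model.eval()
    return model, input_size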