How to load a PyTorch model

Just to be clear, is this what you’re saying?

  1. Save the new fine-tuned model after training
# Train and evaluate
finetuned_model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

# Save fine-tuned model
torch.save(finetuned_model_ft.state_dict(), 'path/to/myModel')
  2. Load the new fine-tuned model
model_name = "resnet"
num_classes = 2
feature_extract = True
use_pretrained = True

my_fine_tuned_model, input_size = initialize_finetuned_model(model_name, num_classes, feature_extract, use_pretrained=True)

my_fine_tuned_model.load_state_dict(torch.load('path/to/myModel'))
my_fine_tuned_model.eval()

I am guessing, then, that the initialize_finetuned_model function used to recreate the fine-tuned model for inference should be a little different from the initialize_model function used for training, because I need to load the fine-tuned version instead of the original resnet model, right?

What I mean by this is that the initialize_model function has this line, model_ft = models.segmentation.fcn_resnet101(pretrained=use_pretrained), where it loads the original model, but when I load the fine-tuned version, instead of that line I would need my_fine_tuned_model = torch.load('path/to/myModel'), correct? So the original initialize_model and the new initialize_finetuned_model functions would look like the code below, correct?

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables, which will be set in this if statement.
    # Each of these variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ FCN_resnet101 """
        model_ft = models.segmentation.fcn_resnet101(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        in_chnls = model_ft.classifier[4].in_channels
        model_ft.classifier[4] = nn.Conv2d(in_chnls, num_classes, 1, 1)
        input_size = 768, 1024  # 224
    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

Initialize fine-tuned model to be used for inference

def initialize_finetuned_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables, which will be set in this if statement.
    # Each of these variables is model specific.
    my_fine_tuned_model = None
    input_size = 0

    if model_name == "resnet":
        """ FCN_resnet101 """
        my_fine_tuned_model = torch.load('path/to/myModel')
        set_parameter_requires_grad(my_fine_tuned_model, feature_extract)
        in_chnls = my_fine_tuned_model.classifier[4].in_channels
        my_fine_tuned_model.classifier[4] = nn.Conv2d(in_chnls, num_classes, 1, 1)
        input_size = 768, 1024  # 224
    else:
        print("Invalid model name, exiting...")
        exit()

    return my_fine_tuned_model, input_size

You could use initialize_model in both use cases (training and inference).
Although it would load the pretrained model first, you would load your fine-tuned state_dict at the end.
I would argue against using torch.load('path_to_model'), as it has some drawbacks, as explained in the Serialization notes.
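In other words, something like this minimal sketch (it assumes the initialize_model definition and the placeholder checkpoint path from your snippets above):

import torch

# Reuse the same initialize_model from above for inference; the pretrained
# weights it loads are then overwritten by the fine-tuned state_dict.
model_name = "resnet"
num_classes = 2
feature_extract = True

model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
model_ft.load_state_dict(torch.load('path/to/myModel'))  # placeholder checkpoint path
model_ft.eval()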

Thanks! I didn’t realize I could use the fine-tuned state_dict with the original pretrained model. One question: in initialize_model I call the set_parameter_requires_grad function (code below, taken from this tutorial). Do I still need to call this function when I use initialize_model for inference?

def set_parameter_requires_grad(model, feature_extracting):
    # Freeze all parameters when feature extracting
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

If you are running your inference code in a with torch.no_grad() block, you wouldn’t have to set the requires_grad attribute to False, but it should also work if you do. :slight_smile:
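For example, something like this (the input shape is just a placeholder; fcn_resnet101 returns a dict, so you would index output['out']):

import torch

# Running inference inside torch.no_grad() means gradients are never tracked,
# regardless of each parameter's requires_grad setting.
my_fine_tuned_model.eval()
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 768, 1024)  # placeholder batch of one RGB image
    output = my_fine_tuned_model(dummy_input)['out']  # segmentation logits
    prediction = output.argmax(dim=1)  # per-pixel class indices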


When I use model = TheModelClass(*args, **kwargs), it always pops up the error NameError: name ‘args’ is not defined. How do I solve it? Thanks so much!

This line of code is just an example, and you should replace it with your actual model class definition as well as its expected input arguments.
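For example, a hypothetical minimal model definition (MyModel and its num_classes argument are just placeholders for your own class):

import torch
import torch.nn as nn

# Hypothetical example class: replace TheModelClass(*args, **kwargs) with your
# own model class, instantiated with concrete arguments.
class MyModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.conv(x)
        x = x.mean(dim=(2, 3))  # global average pooling
        return self.fc(x)

model = MyModel(num_classes=2)  # concrete arguments instead of *args, **kwargs
model.load_state_dict(torch.load('path/to/myModel'))  # a checkpoint saved from this same class
model.eval()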