How to load a PyTorch model

I have saved my model after training using torch.save(the_model.state_dict(), PATH).

While loading, I am confused. Can someone explain this code?

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

The first line of code creates an instance of TheModelClass given the provided arguments.
The parameters of this model will be randomly initialized, either in the way you defined in your model implementation or via the default reset_parameters method of the used layers.
Since you’ve stored the state_dict after training (which is the recommended way of serializing the model!), you can now load the trained parameters from the state_dict into your model instance via the_model.load_state_dict(torch.load(path_to_state_dict)).


It returns: TheModelClass is not defined

TheModelClass is just a placeholder name for your model class name.
E.g. you can define your model as:

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1, 1)
 
    def forward(self, x):
        return self.fc1(x)

model = MyModel()
model.load_state_dict(torch.load(PATH))

In that case, you would be using MyModel as the class name instead of TheModelClass.


What about *args and **kwargs?

These constructs are used to pass a variable number of arguments to a class instantiation or function in Python. Have a look at this explanation for more information.
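As a minimal sketch (the class, parameters, and values here are just hypothetical placeholders, not the actual model from this thread):

from torch import nn

class MyModel(nn.Module):
    # in_features, out_features and bias stand in for whatever constructor
    # arguments your model actually takes
    def __init__(self, in_features, out_features, bias=True):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        return self.fc(x)

# positional arguments are collected in *args, keyword arguments in **kwargs
args = (10, 2)
kwargs = {'bias': False}
model = MyModel(*args, **kwargs)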

I was wondering if I could get a clarification on this. For my TheModelClass or MyModel or whatever we name the class, I assume the structure has to match the structure of the original model, right?

So, for example, if I am fine-tuning an FCN segmentation model, would the class basically mimic this structure?

So in my case:

from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F


class MyModel(nn.Module):
    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(MyModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result

If you are trying to implement a specific model and would like to keep the workflow, then yes, the structure should be as close as possible.
It seems you would like to write a segmentation model using _SimpleSegmentationModel as the base class. If you derive from this class and just want to reuse e.g. the forward function, you don't have to copy it; the parent's method will be called directly instead.

If you are writing a custom model without any base model, you are free to write it as you wish.


You mean do something like this?

from torch import nn

from ._utils import _SimpleSegmentationModel


__all__ = ["FCN"]


class MyModel(_SimpleSegmentationModel):
    """
    Implements a Fully-Convolutional Network for semantic segmentation.
    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """
    pass

I took this from this page

And do I need to have something like this too (taken from the same page)?

class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1)
        ]

        super(FCNHead, self).__init__(*layers)

It really depends on your use case.
Could you explain what you are trying to achieve and based on this we could think about the best approach?

E.g. if you derive your custom model from _SimpleSegmentationModel without changing any method, you could just use the parent model directly.

I am just trying to fine-tune this model, models.segmentation.fcn_resnet101(pretrained=True), with my own data/images to try to get better segmentation accuracy for a specific set of images.


In that case I think just swapping the last layer in the classifier might be sufficient:

model = models.segmentation.fcn_resnet101(pretrained=True)

nb_classes = 2
model.classifier[4] = nn.Conv2d(
    in_channels=512,
    out_channels=nb_classes,
    kernel_size=1,
    stride=1)

Let me know, if that would work for you or if you need to dig into the model.
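If you want to double-check the layer before replacing it, you could print the classifier (a small sketch using the model created above):

print(model.classifier)
# the last entry (index 4) is an nn.Conv2d with in_channels=512,
# which is the layer being swapped above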


I’m confused. How would I tie that to the MyModel class?

You wouldn't need to wrap it in another class. Why do you want to use a custom nn.Module on top?

From this tutorial, I thought I had to load the model this way after saving it.

Save:

torch.save(model.state_dict(), PATH)

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

TheModelClass is just a placeholder for any model.
E.g. you could have written a custom model named MyModel, which you would then use in place of TheModelClass.

If you want to use a pretrained model, you can directly use it without wrapping into another class.
Of course, you can adapt some layers based on your use case (e.g. changing the last classification layer).
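As a minimal sketch of that workflow (assuming the same classifier swap as above and using PATH as a placeholder for the checkpoint of your fine-tuned state_dict):

import torch
from torch import nn
from torchvision import models

# recreate the pretrained model and adapt the last classification layer
model = models.segmentation.fcn_resnet101(pretrained=True)
model.classifier[4] = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=1, stride=1)

# load the fine-tuned parameters into the adapted model
model.load_state_dict(torch.load(PATH))
model.eval()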

So, if I understand you correctly, in my case I wouldn't need this line: model = TheModelClass(*args, **kwargs)?

So just retrieving the model like this would be enough?

model = torch.load(PATH)
model.load_state_dict(torch.load(PATH))
model.eval()

Also, I am a little confused by what you mentioned here, because I already update/create a new layer in the fine-tuning process when I initialize the model, as in the function below. So wouldn't doing this again when retrieving the saved fine-tuned model be redundant?

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement.
    # Each of these variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ FCN_resnet101 """
        model_ft = models.segmentation.fcn_resnet101(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        in_chnls = model_ft.classifier[4].in_channels
        model_ft.classifier[4] = nn.Conv2d(in_chnls, num_classes, 1, 1)
        input_size = 750, 1000
    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

I then use model_ft as an input to my train_model function, which returns the fine-tuned model:

# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Train and evaluate
finetuned_model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

So the next step would then be to save the new finetuned_model_ft like this:

torch.save(finetuned_model_ft.state_dict(), path/to/myModel)

and then retrieve it to be used for inference like this:

my_fine_tuned_model = torch.load(path/to/myModel)
my_fine_tuned_model.load_state_dict(torch.load(path/to/myModel))
my_fine_tuned_model.eval()

Is this correct?

When restoring the model, I would recommend recreating your model the same way you did initially, instead of calling my_fine_tuned_model = torch.load(path/to/myModel).
Besides that, your code should work fine.

I'm confused. What do you mean by "the same way I did it initially"? I mentioned a lot of things in this thread, so I'm not sure which way you are talking about haha :smile:

Haha, sorry for the confusion.
By “initially” I meant you should initialize the model in the same way as you have initialized it before storing the state_dict.

So, if you’ve created the model via model_ft, input_size = initialize_model(...), and stored the state_dict afterwards, you should use the same approach to recreate your model and then just load the state_dict again.
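Putting it together, a rough sketch of the restore step (assuming the same initialize_model arguments from your training script are still available, and using PATH as a placeholder for the same path you passed to torch.save):

import torch

# recreate the model exactly as before training (including the classifier swap)
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# load the fine-tuned parameters and switch to evaluation mode
model_ft.load_state_dict(torch.load(PATH))
model_ft.eval()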