I am using the fastai library (fast.ai) to train an image classifier. The model created by fastai is actually a pytorch model.
Now, I want to use this model from pytorch for inference. Here is my code so far:
the_model = torch.load("./torch_model_v1")
the_model.eval() # shows the entire network architecture
Based on the example shown here: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py, I understand that I need to write my own data loading class that overrides some of the functions in the Dataset class. What is not clear to me is which transformations I need to apply at test time. In particular, how do I normalize the images at test time?
Another question: is my approach of saving and loading the model in pytorch fine? I read in a pytorch tutorial that the approach that I have used is not recommended. The reason is not clear though.
It is not recommended to save the entire model (architecture and weights) the way you did, because that method will not work if you try to load the model in a different project. Saving the whole model pickles it, and the pickle stores only a reference to the file and class that define the model, not the class code itself. For example, if you send your model file ./torch_model_v1 to me and I try to load it with torch.load("./torch_model_v1"), I will get an error, because it’s likely my project won’t have the exact same directory structure as your project.
Instead, you should save only the model weights (the state dict), define the architecture in code, and then load the weights into the new model’s state dict.
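model = Model()  # the model should be defined with the same code you used to create the trained model
state_dict = torch.load("./torch_model_v1.pt")  # assumes this file was written with torch.save(the_model.state_dict(), ...)
model.load_state_dict(state_dict)  # copy the saved weights into the freshly built model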
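For reference, here is a minimal round-trip sketch of the state-dict approach. TinyNet and the temporary file name are hypothetical stand-ins for your real architecture and path, not anything produced by fastai; the point is only that the class is defined in code and just the weights travel through the file.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real architecture."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# "Training" side: save only the weights, not the pickled model object.
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_net_weights.pt")
torch.save(trained.state_dict(), path)

# "Inference" side: rebuild the architecture from code, then load the weights.
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()

# Both models should now produce the same outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
```

Because the file holds only tensors keyed by parameter name, it can be loaded in any project that defines a class with matching layer names and shapes.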
@austin Is it necessary to save the optimizer state dict for inference, as stated in the PyTorch tutorial: "When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model’s state_dict. It is important to also save the optimizer’s state_dict, as this contains buffers and parameters that are updated as the model trains."? Or is it only important for resuming training?
It’s important for resuming training. If you just want to use the model for inference, loading the state_dict and setting the model to eval() should be enough.
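To make the difference concrete, here is a small sketch of a general checkpoint that stores both state dicts, followed by an inference-only load that ignores the optimizer part. The build_model helper, the dictionary keys, and the file name are assumptions chosen for illustration; any consistent names would work.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical architecture; Dropout is included so that eval() matters.
    return nn.Sequential(
        nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2)
    )


model = build_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One dummy training step so the optimizer has momentum buffers to save.
loss = model(torch.randn(5, 4)).sum()
loss.backward()
optimizer.step()

# General checkpoint: model weights plus optimizer state (needed to resume).
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    path,
)

# Inference only: read the checkpoint but use just the model weights.
checkpoint = torch.load(path, map_location="cpu")
inference_model = build_model()
inference_model.load_state_dict(checkpoint["model_state_dict"])
inference_model.eval()  # disables dropout, so outputs are deterministic

x = torch.randn(3, 4)
with torch.no_grad():  # no gradients are needed for inference
    out1 = inference_model(x)
    out2 = inference_model(x)
deterministic = torch.equal(out1, out2)
```

To resume training instead, you would also create a fresh optimizer and call optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) before continuing.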