How to run trained model?

I’m new to PyTorch. I see lots of tutorials that focus on how to use the API to train, but once I have a trained model, what is the definitive way to run it on some data, e.g. for picture classification?

Do you have a specific use case in mind?
The easiest way would be to store the trained model and load it in another script for inference.
Have a look at the Serialization semantics for more information.
Your inference code might load the data directly from your hard drive / network storage or you could create a web app using e.g. Flask and feed the data to your model.
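A minimal sketch of the "train in one script, infer in another" workflow. It assumes a small hypothetical Net class standing in for your model, and a temporary directory standing in for your hard drive or network storage. Only the state_dict (the learned parameters) is saved. The inference side rebuilds the architecture, loads the weights, and switches to eval mode:

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for your trained model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")

    # --- training script: after training, store the learned parameters ---
    trained = Net()
    torch.save(trained.state_dict(), path)

    # --- inference script: rebuild the architecture, then load the weights ---
    model = Net()
    model.load_state_dict(torch.load(path))
    model.eval()

    x = torch.randn(3, 4)
    with torch.no_grad():
        same = torch.equal(model(x), trained(x))

print(same)  # True: the reloaded model reproduces the original outputs
```

In a real setup the two halves live in separate files, and the inference script must be able to import the same Net class definition.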

You can also deploy your model to Caffe2 using ONNX. Here is a nice tutorial.


Well, right now I’m looking at Module.eval as being how it’s done. But basically, any scenario where I would feed some input to a trained model and get a result - any result, any input - back should be suitable enough to teach me where or how the API is used. I’m not so sure there’s more than one method of using a trained model at all, but in all the example code I’ve seen so far, I don’t see eval, and I see a lot of stuff about training.

For inference / evaluation you should call model.eval(), which puts the model and all of its submodules into evaluation mode.
This changes the behavior of some modules. For example, nn.Dropout won’t drop any features anymore, and nn.BatchNorm will (in the default settings) use its running estimates of mean and variance instead of the statistics of the current batch.
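A small sketch of what model.eval() toggles. It uses a hypothetical toy nn.Sequential that is not from the original post. The training flag is flipped recursively, Dropout becomes the identity, and BatchNorm stops depending on the current batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model containing the two module types discussed above.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Dropout(p=0.5))
model.eval()  # sets .training = False on the model and on every submodule
print([m.training for m in model.modules()])  # [False, False, False, False]

# Dropout: random zeroing in training mode, identity in eval mode.
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)
drop.train()
y_train = drop(x)  # each entry is either 0 or scaled by 1 / (1 - p) = 2
drop.eval()
y_eval = drop(x)   # returns x unchanged

# BatchNorm in eval mode normalizes with its stored running_mean / running_var,
# so the output no longer depends on the batch and repeated calls agree exactly.
inp = torch.randn(4, 8)
with torch.no_grad():
    same = torch.equal(model(inp), model(inp))
print(torch.equal(y_eval, x), same)  # True True
```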
That’s basically how you would perform the inference for a classification of images:

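```python
import torch

# Train your model
...
# After training
model.eval()  # switch dropout / batchnorm to inference behavior
data = torch.randn(1, 3, 24, 24)  # Load your data here, this is just dummy data
with torch.no_grad():  # gradients aren't needed for inference
    output = model(data)
prediction = torch.argmax(output, dim=1)  # predicted class index for each sample
```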
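The snippet above leaves model undefined. Here is a self-contained sketch that assumes a hypothetical tiny CNN, TinyClassifier, with 10 classes in place of a trained network. Its weights are random, so the predicted classes are meaningless. It also shows softmax for per-class probabilities and argmax(dim=1) for one predicted class per image in a batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a trained image classifier (10 classes)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # -> (N, 8, 1, 1) for any image size
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.pool(x).flatten(1)  # -> (N, 8)
        return self.fc(x)            # raw scores (logits), shape (N, num_classes)


model = TinyClassifier()
model.eval()                             # inference behavior for dropout / batchnorm

images = torch.randn(4, 3, 24, 24)       # dummy batch of 4 RGB 24x24 images
with torch.no_grad():                    # no gradients needed at inference
    logits = model(images)               # the forward pass happens here
    probs = F.softmax(logits, dim=1)     # per-image class probabilities
    predictions = logits.argmax(dim=1)   # one class index per image

print(logits.shape, predictions.shape)  # torch.Size([4, 10]) torch.Size([4])
```

Passing dim=1 matters once the batch holds more than one image. Without it, argmax flattens the whole tensor and returns a single index.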

That sequence right there was what I was looking for! I’m pretty sure that argmax isn’t the only way to represent the result of some model evaluation in the whole machine learning ecosystem, but all I need to know is what the model is doing for graph traversal on the inside.

I’m familiar with the serialization semantics and have known about them for a while. This will probably be considered a separate question: I see that there is a C++ version of load and save, torch::load and torch::save. But the C++ and Python functions don’t use the same format, and the Python version pickles the model, so the result isn’t easily accessible from C++. 1) Is there an easy way to load and traverse the model in C++, and 2) will there be any future effort to use something more portable, like protobuf, for saving models?

Sure, this was just an example for a typical classification use case.

Production-ready PyTorch is planned to be released later this summer / autumn.
Have a look at the road to PyTorch 1.0 for more information. Until it’s released, you could stick to ONNX and export your models to Caffe2.
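For later readers: the 1.0 release did add a format that C++ can load, namely TorchScript. Below is a minimal sketch with a hypothetical small model. torch.jit.trace records the forward pass, and torch.jit.save writes a file that LibTorch can open from C++ with torch::jit::load. Here the file is reloaded in Python, through an in-memory buffer, to check that it reproduces the original outputs. Recent PyTorch documentation points new projects toward torch.export instead, so treat this as one option rather than the current recommendation:

```python
import io

import torch
import torch.nn as nn

# Hypothetical small model standing in for a trained network.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

example = torch.randn(1, 4)
traced = torch.jit.trace(model, example)  # record the ops run on `example`

buffer = io.BytesIO()  # a real deployment would write a file such as "model.pt"
torch.jit.save(traced, buffer)

buffer.seek(0)
loaded = torch.jit.load(buffer)  # the C++ side would call torch::jit::load("model.pt")

x = torch.randn(3, 4)
with torch.no_grad():
    match = torch.allclose(model(x), loaded(x))
print(match)  # True
```

Tracing only records the operations executed for the example input. Models with data-dependent control flow would need torch.jit.script instead.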

The example is very good, and I’m happy with it, provided I can dig out the traversal I need.

Does the model traversal get performed at model(data) or argmax?

Also, what type is model? I can’t just do torch.load(xyz), because that gives me an OrderedDict.

The forward pass will be performed at model(data).
model itself is an instance of nn.Module (source code).
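The OrderedDict comes from what was saved, not from torch.load itself. A minimal sketch, assuming a hypothetical nn.Linear as the trained model, that contrasts the two things torch.save can store:

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical trained model

# Option 1: save only the parameters -> loads back as an OrderedDict of tensors.
buf_state = io.BytesIO()
torch.save(model.state_dict(), buf_state)
buf_state.seek(0)
state = torch.load(buf_state)
print(type(state).__name__, list(state))  # OrderedDict ['weight', 'bias']

# Option 2: pickle the whole module -> loads back as an nn.Module instance.
# The class definition must be importable at load time, and recent PyTorch
# versions require weights_only=False to unpickle arbitrary objects.
buf_module = io.BytesIO()
torch.save(model, buf_module)
buf_module.seek(0)
whole = torch.load(buf_module, weights_only=False)
print(isinstance(whole, nn.Module), type(whole).__name__)  # True Linear
```

With option 1, which the serialization notes recommend, you construct the module class yourself and pass the dictionary to load_state_dict() to get an nn.Module back.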

Ah ok. Thanks!


Hey wait, in this example, line 207 constructs an instance of the Module subclass RNN and assigns it to the variable rnn. Then, at line 221, rnn is called as though it were a function. What does line 234 resolve to?

```python
# First, construct an instance of the RNN module:
rnn = RNN(n_letters, n_hidden, n_categories)
# Then call the instance as if it were a function:
output, next_hidden = rnn(input, hidden)
```

Calling a model performs its forward pass.
In your example you pass input and hidden to your model, the forward pass is performed, and finally you’ll get output and next_hidden as the result.

No, I mean: which method of the RNN object does the line ‘rnn(input, hidden)’ correspond to?

For example, constructors correspond to __init__, len() corresponds to __len__, etc. But I’ve already constructed the object with RNN(...), so how am I calling an object as though it were a function?

I found it! It’s __call__!
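A short sketch of that mechanism. Any Python class that defines __call__ makes its instances callable. nn.Module defines __call__ so that it runs any registered hooks around a call to self.forward(...). TinyRNN below is a hypothetical cell that only mimics the tutorial's forward(input, hidden) signature. The forward hook shows the one practical difference between calling rnn(...) and calling rnn.forward(...) directly:

```python
import torch
import torch.nn as nn


# Plain Python: defining __call__ makes instances callable like functions.
class Doubler:
    def __call__(self, x):
        return 2 * x


print(Doubler()(21))  # 42


# Hypothetical cell with the same call signature as the tutorial's RNN.
class TinyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), dim=1)
        return self.i2o(combined), self.i2h(combined)


rnn = TinyRNN(5, 3, 2)
calls = []
# Forward hooks receive (module, inputs, output); this one just counts calls.
rnn.register_forward_hook(lambda module, inputs, output: calls.append(1))

inp, hidden = torch.randn(1, 5), torch.zeros(1, 3)
output, next_hidden = rnn(inp, hidden)    # nn.Module.__call__ -> hooks + forward
direct_out, _ = rnn.forward(inp, hidden)  # bypasses __call__, so no hooks run

print(torch.equal(output, direct_out), len(calls))  # True 1
```

This is why the usual advice is to call the module (rnn(...)) rather than rnn.forward(...): both compute the same result, but only the former runs hooks.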