Trying to load a pre-trained vision model

Hello,

I am trying to load a pre-trained vision model (resnet18) from a .pth file (resnet18-xxxx.pth) on my local disk. As far as I can tell, the file only contains the state dictionary.

I understand that to load this into a model I have to instantiate the model class and then call load_state_dict on that instance, as described here: https://pytorch.org/docs/master/notes/serialization.html

However, what parameters should I use to initialize the model class? How can I find them out?

Thanks
Anirudh

You can just create an instance of the model with its default constructor arguments and leave the default (random) initialization as it is, since you’ll load the state_dict anyway and that replaces all parameters and buffers with the pretrained ones.