Loading a PyTorch model

I am trying to load a model in PyTorch, but I am getting the following error:

import torch
from models import micro_models
model = micro_models.NetworkCIFAR
model.load_state_dict(torch.load("model.pth"))

The last line raises the following error:
Traceback (most recent call last):
  File "", line 1, in
TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'


Hi. Make sure that torch.load("model.pth") actually returns a dict object and not something else (perhaps None).
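One way to follow this advice is to inspect the loaded object before passing it on. The helper below is a hypothetical sketch (checkpoint_kind is not a PyTorch function); it only distinguishes a mapping, such as the OrderedDict produced by state_dict(), from a whole pickled nn.Module and from anything else.

```python
from collections.abc import Mapping

import torch


def checkpoint_kind(obj) -> str:
    """Hypothetical helper: roughly classify what torch.load() returned."""
    if isinstance(obj, torch.nn.Module):
        # A whole model saved with torch.save(model, ...), not a state_dict.
        return "module"
    if isinstance(obj, Mapping):
        # e.g. the OrderedDict produced by model.state_dict(); this is the
        # kind of object load_state_dict() expects.
        return "mapping"
    # Anything else, e.g. "NoneType".
    return type(obj).__name__


print(checkpoint_kind(torch.nn.Linear(2, 2).state_dict()))  # mapping
print(checkpoint_kind(None))  # NoneType
```

In practice you would call it as checkpoint_kind(torch.load("model.pth")).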

>>> dt = torch.load("/home/ml/Desktop/workspace/nsga-net-master-trial/search-GA-BiObj-micro-20201201-132957/arch_1/model.pt")
>>> type(dt)
<class 'collections.OrderedDict'>

It returns an OrderedDict.

OK. I tried passing an incompatible dict object to load_state_dict() and got a different error ("Error(s) in loading state_dict").
Here is the snippet:

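import torch
import torchvision
from collections import OrderedDict

model = torchvision.models.resnet18()
od = OrderedDict()  # empty, so it has none of resnet18's parameter keys
torch.save(od, 'tmp.pt')
model.load_state_dict(od)  # RuntimeError: Error(s) in loading state_dict ...
model.load_state_dict(torch.load('tmp.pt'))  # same error when loaded from file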

The only way I can reproduce the error "TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'" is to call the function without any arguments, like this: model.load_state_dict()

What is your PyTorch version, by the way?
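Why the message says state_dict is missing even though an argument was passed can be shown without PyTorch. In the sketch below, Net is a hypothetical stand-in that only mimics the shape of nn.Module.load_state_dict(self, state_dict, ...): calling the method on the class instead of an instance binds the single argument to self, so Python reports state_dict as missing, just as it does when the method is called on an instance with no arguments.

```python
class Net:
    """Hypothetical stand-in for a model class; only the method signature matters."""

    def load_state_dict(self, state_dict):
        return sorted(state_dict)


state = {"weight": 1, "bias": 0}

# Called on an instance: the instance is bound to `self`, `state` to `state_dict`.
print(Net().load_state_dict(state))  # ['bias', 'weight']

for call in (lambda: Net.load_state_dict(state),  # on the class, one argument
             lambda: Net().load_state_dict()):     # on an instance, no argument
    try:
        call()
    except TypeError as err:
        # Both messages end with: missing 1 required positional argument: 'state_dict'
        print(err)
```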


Also, I noticed that the error came from line 1 of some file "", complaining that no argument was provided to load_state_dict(). Without being able to look into that file, it is hard to tell anything further :slight_smile:


I resolved it. I had not created an instance of the model class. You need to create an instance of the model class first and then call load_state_dict() on it:

model = micro_models.NetworkCIFAR(24, 5, 11, False, genotype)  # create an instance first
model.load_state_dict(torch.load("model.pth"))  # then load the saved weights into it

Those are the parameters that have to be passed to the model class to create an instance; they must match the ones used when the model was saved. That line was missing before.
Thank you for your help!
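For reference, here is a minimal, self-contained sketch of the save / instantiate / load pattern. TinyNet is a hypothetical stand-in for micro_models.NetworkCIFAR, and an in-memory buffer replaces model.pth. It shows that the instance has to be built with the same constructor arguments as the saved model; with different arguments the parameter shapes differ and load_state_dict() raises a RuntimeError reporting a size mismatch instead of loading.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for micro_models.NetworkCIFAR: parameter shapes
    depend on the constructor arguments."""

    def __init__(self, channels: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        x = self.conv(x).mean(dim=(2, 3))  # global average pooling
        return self.fc(x)


# Saving side: store only the parameters (the state_dict), not the class.
trained = TinyNet(24, 10)
buffer = io.BytesIO()  # in-memory stand-in for "model.pth"
torch.save(trained.state_dict(), buffer)

# Loading side: 1) create an instance with the same arguments,
#               2) load the saved state_dict into that instance.
model = TinyNet(24, 10)
buffer.seek(0)
model.load_state_dict(torch.load(buffer))
model.eval()  # switch to inference mode before evaluating

# An instance built with different arguments cannot take these weights.
buffer.seek(0)
try:
    TinyNet(16, 10).load_state_dict(torch.load(buffer))
except RuntimeError as err:
    print(err)  # lists the size mismatches between checkpoint and model
```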
