KeyError: 'state_dict'


(zhao nam) #1

Hello, everyone,

I trained a resnet34 model using mainModel.py on my image sets. However, when I tried to evaluate its accuracy, I ran into the following problem.

I first created the model:

    import torchvision.models as models
    model = models.__dict__['resnet34']()

Then I loaded the checkpoint:

    import torch
    from collections import OrderedDict

    checkpoint = torch.load('./model_best10000.pth.tar')

    # create new OrderedDict that does not contain `module.`
    new_checkpoint = OrderedDict()
    for k, v in checkpoint.items():
        name = k[7:]  # remove `module.`
        new_checkpoint[name] = v

    # load params
    model.load_state_dict(new_checkpoint['state_dict'])

However, I got the following error:

    Traceback (most recent call last):
      File "testAccuracy.py", line 94, in
        model.load_state_dict(new_checkpoint['state_dict'])
    KeyError: 'state_dict'

Can anyone tell me what the problem is and how to fix it?

Thanks in advance!

Nam


#2

Could you post the code to save the checkpoint?
Also, could you run the following and post the output?

for key in new_checkpoint:
    print(key)

It seems that new_checkpoint was not saved as the dict you expect.
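To see why the lookup fails, here is a small sketch that replays the loop from the question on a plain dict. It assumes, hypothetically, that the checkpoint is a dict with top-level keys such as 'epoch', 'arch', 'state_dict' and 'optimizer' (a common layout for training checkpoints; the real keys depend on the saving code in mainModel.py). Strings stand in for tensors so it runs without PyTorch. Because the loop walks the top-level keys, k[7:] turns 'state_dict' into 'ict', so new_checkpoint never gets a 'state_dict' entry.

```python
from collections import OrderedDict

# Hypothetical checkpoint layout; the real keys depend on how the checkpoint was saved.
# Strings stand in for tensors so this runs without PyTorch.
checkpoint = {
    "epoch": 90,
    "arch": "resnet34",
    "state_dict": OrderedDict([
        ("module.conv1.weight", "<tensor>"),
        ("module.fc.bias", "<tensor>"),
    ]),
    "optimizer": {},
}

# The loop from the question: it iterates over the *top-level* keys, not the weights.
new_checkpoint = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:]  # meant to drop "module.", but applied to every top-level key
    new_checkpoint[name] = v

print(list(new_checkpoint.keys()))  # ['', 'ict', 'er']
print("state_dict" in new_checkpoint)  # False
```

Note that 'epoch' and 'arch' are both shorter than 7 characters and collapse to the empty string, so one silently overwrites the other.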


(zhao nam) #3

Problem solved. Thank you to everyone who replied.

The following change solved the problem; the loop now iterates over checkpoint['state_dict'] instead of the whole checkpoint:

    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_checkpoint = OrderedDict()
    for k, v in checkpoint['state_dict'].items():  # <- changed line
        name = k[7:]  # remove `module.`
        new_checkpoint[name] = v

    # load params: new_checkpoint is now the state dict itself
    model.load_state_dict(new_checkpoint)
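A slightly more defensive variant of this fix, as a sketch: it accepts either a full checkpoint dict or a bare state dict, and strips the prefix only from keys that actually start with "module." (the prefix nn.DataParallel adds to parameter names) instead of cutting 7 characters off every key. The helper name clean_state_dict is hypothetical, and strings stand in for tensors in the demo.

```python
from collections import OrderedDict

PREFIX = "module."  # prefix that nn.DataParallel puts in front of parameter names


def clean_state_dict(checkpoint):
    """Hypothetical helper: return a state dict whose keys have no 'module.' prefix."""
    # Training checkpoints often nest the weights under 'state_dict';
    # a bare state dict (no such key) is used as-is.
    state_dict = checkpoint.get("state_dict", checkpoint)
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        # Strip the prefix only when it is actually present.
        name = k[len(PREFIX):] if k.startswith(PREFIX) else k
        cleaned[name] = v
    return cleaned


checkpoint = {
    "epoch": 90,
    "state_dict": OrderedDict([
        ("module.conv1.weight", "<tensor>"),
        ("module.fc.bias", "<tensor>"),
    ]),
}
print(list(clean_state_dict(checkpoint).keys()))  # ['conv1.weight', 'fc.bias']
```

With real weights, the result would be passed straight to model.load_state_dict(clean_state_dict(checkpoint)). Alternatively, because the saved keys came from a DataParallel-wrapped model, wrapping the new model in torch.nn.DataParallel before loading lets checkpoint['state_dict'] load without any renaming.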

I also ran into the following error:

    TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, torch.cuda.LongTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight)

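I solved it by adding the following line. The error shows the input as a torch.FloatTensor (on the CPU) but the target as a torch.cuda.LongTensor (on the GPU); moving the model to the GPU makes them match:

    model = torch.nn.DataParallel(model).cuda()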

Hope this post is helpful for newbies like me. Many thanks!

Nam