KeyError: 'state_dict'

Hello, everyone,

I trained a resnet34 model using mainModel.py on my image sets. However, when I tried to evaluate its accuracy, I ran into the following problem.

I first create the model:

model = models.__dict__['resnet34']()
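
For reference, if models here is torchvision.models (as in the ImageNet example), this should be equivalent to constructing the model directly:

import torchvision.models as models

model = models.resnet34()  # same architecture, randomly initialized weights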

Then I loaded the checkpoint:

checkpoint = torch.load('./model_best10000.pth.tar')

# create a new OrderedDict that does not contain the "module." prefix
from collections import OrderedDict
new_checkpoint = OrderedDict()
for k, v in checkpoint.items():
    name = k[7:]  # remove "module."
    new_checkpoint[name] = v

and loaded the params:

model.load_state_dict(new_checkpoint['state_dict'])

However, I got the following error message:

Traceback (most recent call last):
  File "testAccuracy.py", line 94, in <module>
    model.load_state_dict(new_checkpoint['state_dict'])
KeyError: 'state_dict'

Can anyone tell me what the problem is and how to fix it?

Thanks in advance!

Nam

Could you post the code you used to save the checkpoint?
Also, could you run the following and post the output?

for key in new_checkpoint:
    print(key)

It seems new_checkpoint was not saved as the dict you expect.
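
If the checkpoint was saved the way the official ImageNet example does it, the weights are nested under a 'state_dict' key next to other metadata, so the top-level keys are not parameter names. A minimal sketch of that save pattern (the metadata values here are just placeholders):

import torch
import torchvision.models as models

model = models.resnet34()

# the actual weights live one level down, under 'state_dict'
checkpoint = {
    'epoch': 10,
    'arch': 'resnet34',
    'state_dict': model.state_dict(),
}
torch.save(checkpoint, 'model_best10000.pth.tar')

In that case you would need to index into checkpoint['state_dict'] before stripping the 'module.' prefix.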

Problem solved. Thank you to everyone who replied.

The following revision solved the problem: iterating over checkpoint['state_dict'] instead of the checkpoint itself.

# create a new OrderedDict that does not contain the "module." prefix
from collections import OrderedDict
new_checkpoint = OrderedDict()
for k, v in checkpoint['state_dict'].items():  # <-- the key change
    name = k[7:]  # remove "module."
    new_checkpoint[name] = v
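
Putting it together, the full loading sequence I ended up with looks roughly like this (the file name and architecture are from my setup):

import torch
import torchvision.models as models
from collections import OrderedDict

model = models.__dict__['resnet34']()
checkpoint = torch.load('./model_best10000.pth.tar')

# the weights are nested under 'state_dict'; strip the 'module.' prefix
# that nn.DataParallel adds to every parameter name
new_checkpoint = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    new_checkpoint[k[7:]] = v

model.load_state_dict(new_checkpoint)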

I also ran into the following problem:
TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, torch.cuda.LongTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight)

I solved it by adding the following line:
model = torch.nn.DataParallel(model).cuda()
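
If I understand the error correctly, the targets were torch.cuda.LongTensors while the model output was still on the CPU, and wrapping the model in DataParallel plus calling .cuda() moved the parameters (and therefore the output) to the GPU. A minimal sketch of the same idea without DataParallel (single GPU, the tensors here are just dummies):

import torch
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = models.resnet34().to(device)      # model on the same device as the data
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(4, 3, 224, 224, device=device)
labels = torch.randint(0, 1000, (4,), device=device)

loss = criterion(model(images), labels)   # no device mismatch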

I hope this post will be helpful for newbies like me. Many thanks!

nam

Hey guys, I tried this. Can anybody identify what the error is?

model.to(device)

modelCheckpoint = torch.load("…/models/model_final_new_custom.pt")
for key in modelCheckpoint:
    print(key)
model.load_state_dict(modelCheckpoint['state_dict'].items())
model = model.module.Net.features
model.eval()
weights = list(model.parameters())

The output and error I got:
conv1.weight
conv1.bias
conv2.weight
conv2.bias
bn2.weight
bn2.bias
bn2.running_mean
bn2.running_var
bn2.num_batches_tracked
conv3.weight
conv3.bias
bn3.weight
bn3.bias
bn3.running_mean
bn3.running_var
bn3.num_batches_tracked
conv4.weight
conv4.bias
bn4.weight
bn4.bias
bn4.running_mean
bn4.running_var
bn4.num_batches_tracked
fc1.weight
fc1.bias
fbn1.weight
fbn1.bias
fbn1.running_mean
fbn1.running_var
fbn1.num_batches_tracked
fc2.weight
fc2.bias
fbn2.weight
fbn2.bias
fbn2.running_mean
fbn2.running_var
fbn2.num_batches_tracked
fc3.weight
fc3.bias


KeyError                                  Traceback (most recent call last)
<ipython-input> in <module>
      8 for key in modelCheckpoint:
      9     print(key)
---> 10 model.load_state_dict(modelCheckpoint['state_dict'].items())
     11 model = model.module.Net.features
     12 model.eval()

KeyError: 'state_dict'

modelCheckpoint does not contain the key 'state_dict', which raises the error.
I guess modelCheckpoint is already the state_dict itself, so you might just try:

model.load_state_dict(modelCheckpoint)
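
If you are not sure which format a given checkpoint file has, something like this sketch (the helper name is made up) handles both a raw state_dict and a wrapped checkpoint, and also strips an optional 'module.' prefix:

from collections import OrderedDict

import torch

def load_weights(model, path):
    checkpoint = torch.load(path, map_location='cpu')
    # the file may be a raw state_dict or a dict that wraps it under 'state_dict'
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    cleaned = OrderedDict()
    for k, v in state_dict.items():
        # drop the 'module.' prefix left over from nn.DataParallel, if present
        cleaned[k[len('module.'):] if k.startswith('module.') else k] = v

    model.load_state_dict(cleaned)
    return model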

Thank you! You were right.