[solved] KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict'

I am getting the following error while trying to load a saved model.

KeyError: 'unexpected key "module.encoder.embedding.weight" in state_dict'

This is the function I am using to load a saved model (as suggested in http://pytorch.org/docs/notes/serialization.html#recommend-saving-models):

import os
import torch

def load_model_states(model, tag):
    """Load previously saved model states."""
    filename = os.path.join(args.save_path, tag)
    with open(filename, 'rb') as f:
        model.load_state_dict(torch.load(f))

The model is a sequence-to-sequence network whose init function (constructor) is given below.

def __init__(self, dictionary, embedding_index, max_sent_length, args):
    """Constructor of the class."""
    super(Sequence2Sequence, self).__init__()
    self.dictionary = dictionary
    self.embedding_index = embedding_index
    self.config = args
    self.encoder = Encoder(len(self.dictionary), self.config)
    self.decoder = AttentionDecoder(len(self.dictionary), max_sent_length, self.config)
    self.criterion = nn.NLLLoss()  # Negative log-likelihood loss

    # Initializing the weight parameters for the embedding layer in the encoder.
    self.encoder.init_embedding_weights(self.dictionary, self.embedding_index, self.config.emsize)

When I print the model (sequence-to-sequence network), I get the following.

Sequence2Sequence (
  (encoder): Encoder (
    (drop): Dropout (p = 0.25)
    (embedding): Embedding(43723, 300)
    (rnn): LSTM(300, 300, batch_first=True, dropout=0.25)
  )
  (decoder): AttentionDecoder (
    (embedding): Embedding(43723, 300)
    (attn): Linear (600 -> 12)
    (attn_combine): Linear (600 -> 300)
    (drop): Dropout (p = 0.25)
    (out): Linear (300 -> 43723)
    (rnn): LSTM(300, 300, batch_first=True, dropout=0.25)
  )
  (criterion): NLLLoss (
  )
)

So module.encoder.embedding is an embedding layer, and module.encoder.embedding.weight represents the associated weight matrix. So why does it say unexpected key "module.encoder.embedding.weight" in state_dict?


You probably saved the model using nn.DataParallel, which stores the wrapped model in its module attribute, so every key in the saved state_dict starts with "module.". Now you are trying to load it without DataParallel. You can either wrap your network in nn.DataParallel temporarily for loading purposes, or you can load the weights file, create a new OrderedDict without the "module." prefix, and load that instead.
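A minimal sketch of the first option, assuming a small hypothetical stand-in model (make_model) instead of the Sequence2Sequence network and an in-memory buffer instead of a file on disk. Wrapping a fresh plain model in nn.DataParallel gives its state_dict the same "module." prefix as the saved weights, so the keys match; afterwards .module hands back the unwrapped model.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for the real network; any nn.Module behaves the same way.
    return nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))


# Simulate weights saved from a DataParallel-wrapped model:
# every key in this state_dict starts with "module.".
trained = nn.DataParallel(make_model())
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Loading: wrap a fresh plain model temporarily so the key names match ...
model = make_model()
wrapper = nn.DataParallel(model)
wrapper.load_state_dict(torch.load(buffer))

# ... then keep working with the unwrapped model, which now holds the loaded weights.
model = wrapper.module
```

On a CPU-only machine nn.DataParallel simply holds the module without replicating it, so this trick does not need any GPUs just to rename the keys.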


Yes, I used nn.DataParallel. I didn't understand your second suggestion: load the weights file, create a new ordered dict without the module prefix, and load it back. Can you provide an example?

Are you suggesting something like this? (Example taken from https://github.com/OpenNMT/OpenNMT-py/blob/master/train.py)

model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
model_state_dict = {k: v for k, v in model_state_dict.items() if 'generator' not in k}
generator_state_dict = model.generator.module.state_dict() if len(opt.gpus) > 1 else model.generator.state_dict()
#  (4) drop a checkpoint
checkpoint = {
    'model': model_state_dict,
    'generator': generator_state_dict,
    'dicts': dataset['dicts'],
    'opt': opt,
    'epoch': epoch,
    'optim': optim
}
torch.save(checkpoint,
           '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, 100*valid_acc, valid_ppl, epoch))

One question about the code snippet above: what is generator here?


I was thinking about something like the following:

import torch
from collections import OrderedDict

# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
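The same idea as a reusable helper, sketched with a hypothetical name strip_prefix. Unlike the bare k[7:], it only removes the prefix from keys that actually start with "module.", so calling it on a state dict saved without DataParallel leaves the keys intact. Values are passed through untouched, so the usage lines below use plain integers as placeholders for tensors.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, so this is safe to call on
    state dicts saved with or without nn.DataParallel.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Placeholders stand in for the weight tensors.
saved = OrderedDict([("module.encoder.embedding.weight", 1), ("module.decoder.out.bias", 2)])
print(list(strip_prefix(saved)))
# ['encoder.embedding.weight', 'decoder.out.bias']
```

If I remember correctly, newer PyTorch releases (1.9 and later) also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which does this in place; check the documentation for your version before relying on it.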

Thanks a lot! Temporarily adding nn.DataParallel to my network for loading purposes worked. I also tried your second suggested approach, and it worked for me as well. Thanks again :)


@wasiahmad When you added nn.DataParallel temporarily to your network, did you need the same number of GPUs available when loading the model as when you saved it?


A related question: given that saving a DataParallel-wrapped model can cause problems when its state_dict is loaded into an unwrapped model, would you recommend saving the “unwrapped” ‘module’ field of the DataParallel instance instead?

Here is our approach for AlexNet trained with the PyTorch ImageNet example:


This works for me. Thanks a lot !

I am having the same problem, and the OrderedDict trick does not work. I am using PyTorch 0.3, in case anything has changed.

I have a word embedding layer that was trained along with the classification task. Training was successful, but loading the model gave this error:

Traceback (most recent call last):
  File "source/test.py", line 72, in <module>
    helper.load_model_states_from_checkpoint(model, args.save_path + 'model_best.pth.tar', 'state_dict', args.cuda)
  File "/u/flashscratch/flashscratch1/d/datduong/universalSentenceEncoder/source/helper.py", line 55, in load_model_states_from_checkpoint
    model.load_state_dict(checkpoint[tag])
  File "/u/home/d/datduong/project/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 490, in load_state_dict
    .format(name))
KeyError: 'unexpected key "embedding.embedding.embedding.weight" in state_dict'

The key embedding.embedding.embedding.weight exists (see image). Please let me know what to do.

In my opinion, this Q&A should be in an FAQ :)


Check out your saved model file:

import torch

check_point = torch.load('myfile.pth.tar')
print(check_point.keys())

You may find that your ‘check_point’ has several keys, such as ‘state_dict’. In that case, take the state dict out of the checkpoint first and then remove the prefix:

import torch
from collections import OrderedDict

checkpoint = torch.load(resume)  # resume: path to the checkpoint file
state_dict = checkpoint['state_dict']

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove 'module.' of DataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
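Both cases (a bare state_dict, or a checkpoint dict that stores it under 'state_dict') can be folded into one helper. This is a sketch with a hypothetical name, state_dict_from_checkpoint, and it assumes the checkpoint keys the weights under 'state_dict' as described above; the usage lines use integers as placeholders for tensors.

```python
from collections import OrderedDict
from typing import Any, Mapping


def state_dict_from_checkpoint(checkpoint: Mapping[str, Any], key: str = "state_dict",
                               prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Pull the model weights out of a loaded checkpoint and drop the DataParallel prefix.

    Assumes the checkpoint is either a bare state_dict or a dict holding the
    weights under `key`.
    """
    state_dict = checkpoint[key] if key in checkpoint else checkpoint
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# A checkpoint laid out like the one above, with placeholders instead of tensors.
checkpoint = {"epoch": 10, "state_dict": OrderedDict([("module.fc.weight", 0), ("module.fc.bias", 1)])}
print(list(checkpoint.keys()))                       # ['epoch', 'state_dict']
print(list(state_dict_from_checkpoint(checkpoint)))  # ['fc.weight', 'fc.bias']
```

With real files this would be used as model.load_state_dict(state_dict_from_checkpoint(torch.load(path, map_location="cpu"))); map_location="cpu" lets a checkpoint saved on GPUs be opened on a machine without them.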

What about nn.DistributedDataParallel? It seems DistributedDataParallel and DataParallel can load each other’s parameters.
Is there an official way to save/load among DDP/DP/None?
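One way to handle all three cases (a sketch, not an official API): both nn.DataParallel and nn.parallel.DistributedDataParallel keep the original network in their .module attribute, which is also why their state_dict keys share the same "module." prefix. A small hypothetical unwrap helper lets you always save and load the bare model's state_dict. The example uses DataParallel on CPU, because constructing DDP requires an initialized process group.

```python
import io

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying model if it is wrapped in DataParallel or DistributedDataParallel."""
    # Both wrappers store the original network as `.module`,
    # which is where the "module." prefix in the state_dict keys comes from.
    if isinstance(model, (nn.DataParallel, DistributedDataParallel)):
        return model.module
    return model


# Save from a wrapped model without the prefix ...
wrapped = nn.DataParallel(nn.Linear(3, 2))
buffer = io.BytesIO()
torch.save(unwrap(wrapped).state_dict(), buffer)
buffer.seek(0)

# ... and load into a plain model (or into unwrap(another_wrapper)) without renaming keys.
plain = nn.Linear(3, 2)
plain.load_state_dict(torch.load(buffer))
```

Because the saved file never contains the prefix, the same checkpoint loads into a plain model, a DataParallel model, or a DDP model, as long as loading also goes through unwrap().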

Just do this:

model = torch.load(train_model)
…
net.load_state_dict(model['state_dict'])

It works for me!

Thanks a lot, this worked for me.

Instead of deleting the “module.” string from all the state_dict keys, you can save your model with:
torch.save(model.module.state_dict(), path_to_file)
instead of
torch.save(model.state_dict(), path_to_file)
That way you don’t get the “module.” prefix in the first place.


Thanks for your hints! It saved me time.

That works simply and perfectly for me! Thanks.


In case someone needs it, this function can handle loading weights with and without ‘module’.

To save the model without ‘module’, you may try this.


After PyTorch 1.xx this was fixed; now you only need to do this:

if isinstance(args.pretrained, torch.nn.DataParallel):
    args.pretrained = args.pretrained.module