Save/Load Fails in Model with LSTM num_layers > 1

I have been experimenting with a forecasting model composed as follows:

# nn.Module subclass, __init__, etc.
self.lstm = nn.LSTM(input_size=12, hidden_size=6,
                    num_layers=3, batch_first=True)
self.linear = nn.Linear(6, 2)
self.output_layer = nn.Linear(2, 1)

torch.save({
    'epoch': 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss
}, 'model.tar')

The model trains, evaluates, and saves fine. However, it fails when I attempt to load it as described in the docs:

model = LSTModel()
optimizer = torch.optim...

checkpoint = torch.load('onelayer')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
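Note that in the snippets above the checkpoint is saved to 'model.tar' but loaded from 'onelayer'. A minimal sketch of a save/load round trip that uses a single path variable for both steps, so the two cannot drift apart, might look like this. `LSTModel`'s forward pass, the Adam optimizer, the learning rate and the temporary path are all assumptions made for illustration, not taken from the question.

```python
import os
import tempfile

import torch
import torch.nn as nn


class LSTModel(nn.Module):
    # Hypothetical reconstruction of the question's model; the forward
    # pass (last time step -> linear -> output layer) is an assumption.
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=12, hidden_size=6,
                            num_layers=3, batch_first=True)
        self.linear = nn.Linear(6, 2)
        self.output_layer = nn.Linear(2, 1)

    def forward(self, x):
        out, _ = self.lstm(x)          # out: (batch, seq, hidden_size)
        out = self.linear(out[:, -1])  # keep only the last time step
        return self.output_layer(out)


# One path variable shared by saving and loading (location is an assumption).
CHECKPOINT_PATH = os.path.join(tempfile.gettempdir(), "model.tar")

model = LSTModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer

# A single training step so the optimizer has some state to save.
x = torch.randn(4, 5, 12)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

torch.save({
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),
}, CHECKPOINT_PATH)

# Fresh instances, restored from the very same path.
restored = LSTModel()
restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-3)
checkpoint = torch.load(CHECKPOINT_PATH)
result = restored.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
restored.eval()

print(result)  # <All keys matched successfully>
```

With a 3-layer LSTM on both sides, all keys, including the `_l1` and `_l2` ones, should match.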

I receive the following when attempting to load the model’s state dict:

Error(s) in loading state_dict for ERNN:
Missing key(s) in state_dict: "lstm.weight_ih_l1", "lstm.weight_hh_l1", "lstm.bias_ih_l1", "lstm.bias_hh_l1", "lstm.weight_ih_l2", "lstm.weight_hh_l2", "lstm.bias_ih_l2", "lstm.bias_hh_l2", "linear.weight", "linear.bias", "output_layer.weight", "output_layer.bias".
Unexpected key(s) in state_dict: "mapping.weight", "mapping.bias", "timeLayer.weight", "timeLayer.bias".
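nn.LSTM registers one set of weights and biases per layer, suffixed `_l0`, `_l1`, and so on. The missing LSTM keys above are all `_l1`/`_l2` while no `_l0` key is listed, which is consistent with the checkpoint having been written by a model whose LSTM had only one layer. A small sketch to compare the key names (the helper `lstm_keys` is hypothetical):

```python
import torch.nn as nn


def lstm_keys(num_layers):
    """Return the state_dict keys of an nn.LSTM with the question's sizes."""
    lstm = nn.LSTM(input_size=12, hidden_size=6,
                   num_layers=num_layers, batch_first=True)
    return list(lstm.state_dict().keys())


one_layer = lstm_keys(1)
three_layers = lstm_keys(3)
print(one_layer)
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0']

# Keys that exist only in the 3-layer LSTM, i.e. the ones reported as missing.
extra = [k for k in three_layers if k not in one_layer]
print(extra)
# ['weight_ih_l1', 'weight_hh_l1', 'bias_ih_l1', 'bias_hh_l1',
#  'weight_ih_l2', 'weight_hh_l2', 'bias_ih_l2', 'bias_hh_l2']
```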

I can see that these keys are indeed missing. Is there something I am doing wrong when saving the model?

I guess you might be loading the wrong checkpoint, since your model's state_dict can be stored and loaded successfully:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.lstm = nn.LSTM(
            input_size=12, hidden_size=6, num_layers=3, batch_first=True)
        self.linear = nn.Linear(6, 2)
        self.output_layer = nn.Linear(2, 1)

model = MyModel()
sd = model.state_dict()

model = MyModel()
print(model.load_state_dict(sd))
# <All keys matched successfully>

Do the unexpected keys look familiar? E.g., did you use mapping and timeLayer in another script, which might have overwritten the needed checkpoint?
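To illustrate this hypothesis, here is a sketch that reproduces the reported mismatch. `OtherModel` is hypothetical: it stands in for a model from another script with a one-layer LSTM plus layers named `mapping` and `timeLayer`, whose sizes are invented. Passing `strict=False` to `load_state_dict` loads the matching keys and returns the missing and unexpected ones instead of raising, which can help identify which model a checkpoint actually came from.

```python
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=12, hidden_size=6,
                            num_layers=3, batch_first=True)
        self.linear = nn.Linear(6, 2)
        self.output_layer = nn.Linear(2, 1)


class OtherModel(nn.Module):
    # Hypothetical model from "another script": a 1-layer LSTM plus
    # layers named mapping and timeLayer (sizes are made up).
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=12, hidden_size=6,
                            num_layers=1, batch_first=True)
        self.mapping = nn.Linear(6, 2)
        self.timeLayer = nn.Linear(2, 1)


other_sd = OtherModel().state_dict()

model = MyModel()
# strict=False: load what matches and report the rest instead of raising.
result = model.load_state_dict(other_sd, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

The printed lists should correspond to the keys in the error message from the question.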
