How to resume training an LSTM with PyTorch after an interruption

Hello, I often have power cuts at home, and I would like to be able to continue training an LSTM with PyTorch after a cut. Is this code suitable for reloading the model and then continuing the training?

import argparse
import os

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from model import Model
from dataset import Dataset

def save_checkpoint(dataset, model, args, optimizer=None, epoch=None,
                    path='filename.pth.tar'):
    """Save everything needed to resume training after an interruption.

    The checkpoint is a dict with the keys ``'epoch'`` (index of the epoch
    that just finished, or ``None``) and ``'model_state_dict'``. When
    *optimizer* is given, it also holds ``'optimizer_state_dict'``. The LSTM
    weights are part of ``model.state_dict()``. The LSTM hidden/cell states
    are not saved, because ``train`` re-initialises them at the start of
    every epoch.

    Args:
        dataset: Unused; kept for call compatibility.
        model: Module whose ``state_dict()`` is saved.
        args: Unused; kept for call compatibility.
        optimizer: Optional optimizer whose state (e.g. Adam moments) is saved.
        epoch: Index of the epoch that has just completed.
        path: Destination file.
    """
    checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict()}
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    # Write to a temporary file, force it to disk, then atomically swap it in.
    # A power cut in the middle of saving then leaves the previous checkpoint
    # intact instead of a truncated file.
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as f:
        torch.save(checkpoint, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)


def load_checkpoint(dataset, model, args, optimizer=None,
                    path='filename.pth.tar', map_location=None):
    """Restore *model* (and optionally *optimizer*) from *path* if it exists.

    Sets ``args.start_epoch`` to the first epoch that still has to run. It is
    0 when there is no checkpoint or the checkpoint records no epoch;
    otherwise it is the saved epoch + 1. Accepts the dict written by
    ``save_checkpoint``, a ``{'state_dict': ...}`` dict, or a bare
    ``model.state_dict()``.

    Returns:
        ``(dataset, model, args)``. *model* is the same object that was
        passed in, with its weights loaded in place.
    """
    args.start_epoch = 0
    if not os.path.exists(path):
        return dataset, model, args
    # Only load checkpoints you created yourself: depending on the PyTorch
    # version and arguments, torch.load may unpickle arbitrary objects.
    checkpoint = torch.load(path, map_location=map_location)
    if 'model_state_dict' in checkpoint:
        model_state = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        model_state = checkpoint['state_dict']
    else:
        model_state = checkpoint  # a bare model.state_dict()
    model.load_state_dict(model_state)
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint.get('epoch')
    if epoch is not None:
        args.start_epoch = epoch + 1
    return dataset, model, args


def train(dataset, model, args, checkpoint_path='filename.pth.tar',
          resume=True):
    """Train *model* on *dataset*, saving a checkpoint after every epoch.

    With ``resume=True``, an existing checkpoint at *checkpoint_path* is
    loaded first (weights, optimizer state and epoch counter). Re-running
    after an interruption therefore continues after the last completed epoch
    instead of starting over.
    """
    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    # The optimizer must exist before loading, so its state can be restored.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    start_epoch = 0
    if resume:
        load_checkpoint(dataset, model, args, optimizer=optimizer,
                        path=checkpoint_path)
        start_epoch = args.start_epoch
    model.train()
    for epoch in range(start_epoch, args.max_epochs):
        state_h, state_c = model.init_state(args.sequence_length)
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)
            # Carry the state values into the next batch, but cut the graph
            # so backward() does not reach into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()
            loss.backward()
            optimizer.step()
            print({'epoch': epoch, 'batch': batch, 'loss': loss.item()})
        save_checkpoint(dataset, model, args, optimizer=optimizer,
                        epoch=epoch, path=checkpoint_path)

def predict(dataset, model, text, next_words=100):
    """Generate *next_words* words that continue *text*.

    *text* is split on single spaces. Each step feeds the most recent
    ``len(words)`` words (the prompt length) to the model. It then samples
    the next word from the softmax of the logits at the last position and
    appends it.

    Returns:
        The list of prompt words followed by the generated words.

    Raises:
        KeyError: if a prompt word is not in ``dataset.word_to_index``.
    """
    words = text.split(' ')
    window = len(words)
    # Reverse lookup built from the same vocabulary used to encode the prompt.
    index_to_word = {index: word for word, index in dataset.word_to_index.items()}
    model.eval()  # inference mode (turns off dropout, if the model has any)
    state_h, state_c = model.init_state(window)
    with torch.no_grad():
        for _ in range(next_words):
            # A fixed-length window matches the state shape from init_state(window).
            x = torch.tensor([[dataset.word_to_index[w] for w in words[-window:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(index_to_word[int(word_index)])
    return words

# Guarded so that importing this module (e.g. to reuse predict) does not
# start a training run.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--sequence-length', type=int, default=3)
    args = parser.parse_args()
    dataset = Dataset(args)
    model = Model(dataset)
    train(dataset, model, args)

    # to store the final weights
    torch.save({
        'state_dict': model.state_dict()
    }, 'filename.pth.tar')

If so, how can I load the model? I also don’t know how to load the optimizer state.

device = torch.device('cpu')
model = Model(dataset)
# torch.load only returns the saved object. The weights still have to be
# copied into the model, using the key they were saved under. An optimizer is
# restored the same way: save optimizer.state_dict() next to the weights, then
# call optimizer.load_state_dict(...) on a freshly created optimizer.
checkpoint = torch.load('filename.pth.tar', map_location=device)
model.load_state_dict(checkpoint['state_dict'])

Please see ModelCheckpoint (the PyTorch Lightning checkpointing callback).

Hello, here is the code of the model, in case that is what you were asking for; otherwise, I did not understand your request.

import torch
from torch import nn
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear.

    ``nn.LSTM`` is built without ``batch_first``, so it treats dim 0 of its
    input as the time axis and dim 1 as the batch axis. ``init_state``
    therefore sizes the state's batch axis with *sequence_length*.

    NOTE(review): the layer constructor arguments were lost when this code
    was pasted. They are reconstructed from how the layers are used below.
    Any extra options the author had (e.g. LSTM dropout) cannot be
    recovered; confirm.
    """

    def __init__(self, dataset):
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Return ``(logits over the vocabulary, new (h, c) state)``."""
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zeroed ``(h, c)`` of shape (num_layers, sequence_length, lstm_size)."""
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))

I shared a link with you about a PyTorch Lightning callback that saves the model weights periodically. It will help you when there is a power failure: you can load the latest weights and continue training.
Did you read the link I gave you?

Ah, excuse me, I didn’t see that it was a link; I’ll look at it now. Thank you!

Hello, I am trying to install Lightning, because this code does not run on my workstation yet.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the two best checkpoints by validation loss. The LightningModule must
# log "val_loss" (self.log("val_loss", ...)) for this monitor to work.
# save_last=True also writes last.ckpt whenever a checkpoint is saved. That
# is the file to resume from after a power cut:
#     trainer.fit(model, ckpt_path="last")
checkpoint_callback = ModelCheckpoint(
    dirpath="my/path/", save_top_k=2, monitor="val_loss", save_last=True
)
trainer = Trainer(callbacks=[checkpoint_callback])

I tried pip install, and also installing git so I could install it directly from GitHub… So far without much success… I’ll get back to you soon.