Hello I often have power cuts at home, I wish to be able to continue the training of an LSTM with pytorch after a cut. Is this code suitable for reloading the model and continuing the training afterwards?
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from model import Model
from dataset import Dataset
def save_checkpoint(dataset, model, args):
# Save the model state dict
model_state = model.state_dict()
# Save the LSTM cell states
lstm_states = []
for name, param in model_state.items():
if 'lstm_cell' in name:
lstm_states.append(param.state_dict())
# Save the optimizer state
optimizer_state = optim.Adam.state_dict()
# Save the checkpoint
torch.save_pretrained(model_state, lstm_states, optimizer_state, ' filename.pth.tar')
def load_checkpoint(dataset, model, args):
model = Model(dataset)
torch.load('filename.pth.tar')
model.eval()
return dataset, model, args
def train(dataset, model, args):
model.train()
dataloader = DataLoader(dataset, batch_size=args.batch_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(args.max_epochs):
state_h, state_c = model.init_state(args.sequence_length)
for batch, (x, y) in enumerate(dataloader):
optimizer.zero_grad()
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
loss = criterion(y_pred.transpose(1, 2), y)
state_h = state_h.detach()
state_c = state_c.detach()
loss.backward()
optimizer.step()
print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
save_checkpoint(dataset, model, args)
def predict(dataset, model, text, next_words=100):
model.eval()
words = text.split(' ')
state_h, state_c = model.init_state(len(words))
for i in range(0, next_words):
x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
last_word_logits = y_pred[0][-1]
p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
word_index = np.random.choice(len(last_word_logits), p=p)
words.append(dataset.index_to_word[word_index])
return words
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=200)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=3)
args = parser.parse_args()
dataset = Dataset(args)
model = Model(dataset)
#load_checkpoint(dataset, model, args)
train(dataset, model, args)
# to store
torch.save({
'state_dict': model.state_dict()
}, 'filename.pth.tar')
If so, how can I load the model? I don’t know how to load optimizers.
device = torch.device('cpu')
model = Model(dataset)
torch.load('filename.pth.tar', map_location=device)