import time

import torch
from torch import nn

# The helpers used below (train_on_gpu, get_batches, one_hot_encode, time_since)
# are expected to be defined earlier in this file.
def train(net, data, model_name, batch_size=10, seq_length=50, lr=0.001, clip=5,
          print_every_n_step=50, save_every_n_step=5000):
    net.train()
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    if train_on_gpu:
        net.cuda()

    n_chars = len(net.chars)
    counter = epoch = 0
    loss_history = []
    start = time.time()
    while True:
        epoch += 1
        # Re-initialize the hidden state at the start of each epoch
        h = net.init_hidden(batch_size)

        for x, y in get_batches(data, batch_size, seq_length):
            counter += 1

            # One-hot encode the input batch and convert both arrays to tensors
            x = one_hot_encode(x, n_chars)
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)
            if train_on_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()

            # Detach the hidden state so gradients don't backpropagate through the entire history
            h = tuple(each.data for each in h)

            net.zero_grad()
            output, h = net(inputs, h)
            loss = criterion(output, targets.view(batch_size * seq_length))
            loss.backward()

            # Clip gradients to prevent them from exploding
            nn.utils.clip_grad_norm_(net.parameters(), clip)
            opt.step()
            # Loss stats
            if counter % print_every_n_step == 0:
                print(f"Epoch: {epoch:5} | Step: {counter:6} | Loss: {loss.item():.4f} | Elapsed Time: {time_since(start)}")

            if counter % save_every_n_step == 0:
                print(f"Epoch: {epoch:5} | Step: {counter:6} | Loss: {loss.item():.4f} | Elapsed Time: {time_since(start)}")
                print(" --- Save checkpoint ---")
                checkpoint = {
                    'n_hidden': net.n_hidden,
                    'n_layers': net.n_layers,
                    'state_dict': net.state_dict(),
                    'tokens': net.chars,
                    'loss_history': loss_history,
                }
                torch.save(checkpoint, f"{model_name}/epoch_{epoch}.pth")

            loss_history.append(loss.item())
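

# A minimal usage sketch with hypothetical names: it assumes a CharRNN model class,
# its character vocabulary `chars`, and an integer-encoded text array `encoded`.
# The checkpoint directory must exist first, since train() writes to
# f"{model_name}/epoch_{epoch}.pth", and train() runs until interrupted,
# saving a checkpoint every save_every_n_step steps.
import os

model_name = "char_rnn_checkpoints"
os.makedirs(model_name, exist_ok=True)

net = CharRNN(chars, n_hidden=512, n_layers=2)  # assumed constructor signature
train(net, encoded, model_name, batch_size=128, seq_length=100, lr=0.001)

# Restoring a saved checkpoint later reuses the keys stored above:
checkpoint = torch.load(f"{model_name}/epoch_1.pth")
loaded = CharRNN(checkpoint['tokens'],
                 n_hidden=checkpoint['n_hidden'],
                 n_layers=checkpoint['n_layers'])
loaded.load_state_dict(checkpoint['state_dict'])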