import numpy as np
import torch
import torch.nn as nn
from datetime import datetime


def train(net, data, epochs=10, n_seqs=10, n_steps=50, lr=0.001, clip=5, val_frac=0.1,
          device=torch.device('cpu'), name='checkpoint', early_stop=True, plot=False):
    """
    Training loop for a character-level RNN.
    """
    net.train()  # switch into training mode
    opt = torch.optim.Adam(net.parameters(), lr=lr)  # initialize optimizer
    criterion = nn.CrossEntropyLoss()  # initialize loss function
    # split the data into training and validation sets
    val_idx = int(len(data) * (1 - val_frac))
    data, val_data = data[:val_idx], data[val_idx:]
    net.to(device)  # move neural net to GPU/CPU memory
    min_val_loss = float('inf')  # initialize minimal validation loss
    train_history = {'epoch': [], 'loss': [], 'val_loss': []}
    n_chars = len(net.chars)  # get size of vocabulary
    # main loop over training epochs
    for e in range(epochs):
        hidden = None  # reset hidden state after each epoch
        # loop over batches (see the get_batches sketch below)
        for x, y in get_batches(data, n_seqs, n_steps):
            # one-hot encode the batch and create torch tensors
            # (see the one_hot_encode sketch below)
            x = one_hot_encode(x, n_chars)
            inputs = torch.from_numpy(x).to(device)
            targets = torch.tensor(y, dtype=torch.long).to(device)
            # reset gradient information
            net.zero_grad()
            # generate network output
            output, hidden = net(inputs, hidden)
            # compute loss; targets are already long tensors on the right device,
            # so no .type(torch.LongTensor) cast (which would move them to the CPU)
            loss = criterion(output, targets.view(n_seqs * n_steps))
            # compute gradients
            loss.backward()
            # gradient clipping to prevent exploding gradients
            nn.utils.clip_grad_norm_(net.parameters(), clip)
            # optimize
            opt.step()
            # prevent backpropagating through the entire training history
            # by detaching hidden state and cell state
            hidden = (hidden[0].detach(), hidden[1].detach())
        # validation step is done in eval mode and without tracking gradients
        net.eval()
        with torch.no_grad():
            val_h = None
            val_losses = []
            for x, y in get_batches(val_data, n_seqs, n_steps):
                x = one_hot_encode(x, n_chars)
                inputs = torch.from_numpy(x).to(device)
                targets = torch.tensor(y, dtype=torch.long).to(device)
                output, val_h = net(inputs, val_h)
                val_loss = criterion(output, targets.view(n_seqs * n_steps))
                val_losses.append(val_loss.item())
            # compute mean validation loss over batches
            mean_val_loss = np.mean(val_losses)
        net.train()  # switch back to training mode
        # track progress
        train_history['epoch'].append(e + 1)
        train_history['loss'].append(loss.item())
        train_history['val_loss'].append(mean_val_loss)
        # print training progress
        print("{} Epoch: {:.0f}/{:.0f} Loss: {:.4f} Val Loss: {:.4f}".format(
            datetime.now().strftime('%H:%M:%S'),
            e + 1, epochs,
            loss.item(),
            mean_val_loss))
        # save model checkpoint if validation loss has decreased
        # (see the save_checkpoint sketch below)
        if mean_val_loss < min_val_loss:
            save_checkpoint(net, opt, name + '.net', train_history=train_history)
            min_val_loss = mean_val_loss
        # stop training if the validation loss has not decreased for 10 epochs
        if early_stop:
            if e - np.argmin(train_history['val_loss']) > 10:
                print('Validation loss does not decrease further, stopping training.')
                break
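

# --- Helper sketches -------------------------------------------------------
# The helpers used above (one_hot_encode, get_batches, save_checkpoint) are
# defined elsewhere; the sketches below are reconstructions of the interface
# train() assumes, not the original implementations.

# one_hot_encode(x, n_chars) is assumed to turn an integer-encoded batch of
# shape (n_seqs, n_steps) into float32 one-hot vectors of shape
# (n_seqs, n_steps, n_chars), matching the torch.from_numpy call above.
def one_hot_encode(arr, n_labels):
    # scatter ones at the label positions, then restore the batch shape
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    one_hot[np.arange(arr.size), arr.flatten()] = 1.
    return one_hot.reshape((*arr.shape, n_labels))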
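
# get_batches(data, n_seqs, n_steps) is assumed to yield (x, y) pairs of
# integer arrays of shape (n_seqs, n_steps), where y holds the next-character
# targets (x shifted left by one step). A minimal sketch:
def get_batches(data, n_seqs, n_steps):
    batch_size = n_seqs * n_steps
    n_batches = len(data) // batch_size
    # keep only full batches and arrange the data into n_seqs parallel rows
    data = np.asarray(data)[:n_batches * batch_size].reshape((n_seqs, -1))
    for n in range(0, data.shape[1], n_steps):
        x = data[:, n:n + n_steps]
        # targets: inputs shifted one character, wrapping around at the end
        y = np.zeros_like(x)
        y[:, :-1], y[:, -1] = x[:, 1:], data[:, (n + n_steps) % data.shape[1]]
        yield x, y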
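
# save_checkpoint(net, opt, filename, train_history=...) is assumed to persist
# model and optimizer state via the standard torch.save API; the checkpoint
# keys here are assumptions, not the original code.
def save_checkpoint(net, opt, filename, train_history=None):
    torch.save({'model_state_dict': net.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'train_history': train_history},
               filename)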
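
# train() also expects a model with a .chars vocabulary attribute whose
# forward(x, hidden) accepts hidden=None and returns logits flattened to
# (n_seqs * n_steps, n_chars) plus an LSTM-style (h, c) hidden tuple (train()
# detaches hidden[0] and hidden[1]). A minimal sketch of such a model, with
# assumed layer sizes:
class CharRNN(nn.Module):
    def __init__(self, chars, n_hidden=256, n_layers=2, drop_prob=0.5):
        super().__init__()
        self.chars = chars
        self.n_hidden = n_hidden
        self.lstm = nn.LSTM(len(chars), n_hidden, n_layers,
                            dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(n_hidden, len(chars))

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)  # hidden=None -> zero-initialized
        out = self.dropout(out)
        out = out.contiguous().view(-1, self.n_hidden)  # flatten batch and time
        return self.fc(out), hidden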
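
# Hypothetical end-to-end usage; the corpus file name and the hyperparameters
# are assumptions, not part of the original code:
if __name__ == '__main__':
    with open('input.txt') as f:
        text = f.read()
    chars = tuple(sorted(set(text)))
    char2int = {ch: ii for ii, ch in enumerate(chars)}
    encoded = np.array([char2int[ch] for ch in text])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = CharRNN(chars)
    train(net, encoded, epochs=20, n_seqs=128, n_steps=100, device=device)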