Hello everyone, good day
I’m using two datasets to train my model sequentially. The plan is to save a checkpoint after training the model on the first training set, then load the second dataset together with that checkpoint and resume training from where it left off, as described in the article below:
https://pytorch.org/tutorials/beginner/saving_loading_models.html
The problem is that the validation loss goes up during the second iteration. Does anyone have suggestions for fixing this?
The first iteration:
# Stage one: train on the first dataset and keep the checkpoint with the
# lowest validation loss seen so far in 'model.pt'.
N_EPOCHS = 20
CLIP = 1  # passed to train(); presumably the gradient-clipping norm

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Overwrite the checkpoint only when validation loss improves, so
    # 'model.pt' always holds the best stage-one weights.  The optimizer
    # state is stored as well so stage two can resume from it.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save({'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'valid_loss': valid_loss}, 'model.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')
Load checkpoint:
# Reload the best stage-one checkpoint written by the first training loop.
checkpoint = torch.load('model.pt')

# Put the saved weights back into the model, and restore the optimizer's
# state dict (its per-parameter state plus param-group hyperparameters such
# as the learning rate) so training continues rather than starting fresh.
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

# Bookkeeping saved next to the weights: the epoch count at save time and
# the validation loss measured then (on the stage-one validation data).
epoch, valid_loss = checkpoint['epoch'], checkpoint['valid_loss']
The second iteration:
# Stage two: resume from the restored stage-one checkpoint and keep training.
# Assumes `checkpoint` has already been loaded and that train_iterator /
# valid_iterator now point at the data intended for this stage.
N_EPOCHS = 10
CLIP = 1  # passed to train(); presumably the gradient-clipping norm

# The 'valid_loss' stored in the checkpoint was measured on the stage-one
# validation data, so it is not comparable with losses on the current
# valid_iterator when that iterator is a different dataset.  Re-evaluate the
# restored model on the current validation data to get a comparable
# baseline.  If the validation set is unchanged, this simply reproduces the
# stored value.  It also serves as an "epoch 0" reference for judging
# whether validation loss actually rises during this stage.
best_valid_loss = evaluate(model, valid_iterator, criterion)
print(f'Resumed | Val. Loss: {best_valid_loss:.3f} | Val. PPL: {math.exp(best_valid_loss):7.3f}')

# Continue the epoch count from the checkpoint, so printed epoch numbers and
# the 'epoch' saved below are cumulative instead of restarting at 1.
start_epoch = checkpoint['epoch']

# NOTE(review): if valid_iterator still evaluates on dataset one, a rising
# loss here is most likely the model drifting toward dataset two
# (catastrophic forgetting), not a checkpointing bug.  A lower learning rate
# for this stage, or mixing in some dataset-one data, may help.
for epoch in range(start_epoch, start_epoch + N_EPOCHS):
    start_time = time.time()

    # Per the original notes, train() switches to model.train() and
    # evaluate() switches to model.eval().
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Only overwrite 'model.pt' when this stage beats the baseline above, so
    # the stage-one weights survive unless they are actually improved upon.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save({'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'valid_loss': valid_loss}, 'model.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')