Resume training: validation loss going up

Hello everyone, good day :slight_smile:

I’m using two datasets to train my model sequentially. The plan is to save a checkpoint after training the model on the first training set, then load the second dataset and the checkpoint to resume training from the last point, as described in the article below:

https://pytorch.org/tutorials/beginner/saving_loading_models.html

The problem is that the validation loss goes up during the second iteration. Any suggestions on how to fix this problem?

The first iteration:

import math
import time

import torch

N_EPOCHS = 20
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
       
        torch.save({'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'valid_loss': valid_loss}, 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Load checkpoint:

checkpoint = torch.load('model.pt')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch'] 
valid_loss = checkpoint['valid_loss']

The second iteration:

N_EPOCHS = 10
CLIP = 1

best_valid_loss = checkpoint['valid_loss']  # start from the checkpoint's loss instead of float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)  # train() calls model.train()
    valid_loss = evaluate(model, valid_iterator, criterion)  # evaluate() calls model.eval()
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
       
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
 
        torch.save({'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'valid_loss': valid_loss}, 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
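
Note that `best_valid_loss` in the second loop starts from a value that was measured on the first dataset's validation set, so it is not directly comparable to losses on the second validation set. A minimal sketch of re-measuring the baseline on the new validation data before resuming is shown here; the tiny model, the random tensors, and this `evaluate` helper are hypothetical stand-ins, not the code from this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def evaluate(model, x, y, criterion):
    """Return the loss on one batch in eval mode, without tracking gradients."""
    model.eval()
    with torch.no_grad():
        return criterion(model(x), y).item()


model = nn.Linear(4, 1)  # hypothetical stand-in for the restored model
criterion = nn.MSELoss()

# Hypothetical validation sets drawn from two different distributions.
x_valid_1, y_valid_1 = torch.randn(32, 4), torch.randn(32, 1)
x_valid_2, y_valid_2 = 3.0 * torch.randn(32, 4) + 2.0, torch.randn(32, 1)

# What the checkpoint stores: the loss on the first validation set.
checkpoint_valid_loss = evaluate(model, x_valid_1, y_valid_1, criterion)
# Same weights, evaluated on the second validation set.
baseline_valid_loss = evaluate(model, x_valid_2, y_valid_2, criterion)

# Start the second run from the baseline measured on the new data.
best_valid_loss = baseline_valid_loss
print(f'checkpoint: {checkpoint_valid_loss:.3f} | new baseline: {baseline_valid_loss:.3f}')
```

If the two numbers differ a lot before any further training, the gap comes from the data rather than from the checkpoint.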

Assuming the code in train and evaluate is correct, I cannot see any obvious error.
Could you post an executable code snippet (perhaps using random inputs) that reproduces this issue?


Okay sir, first I appreciate your support.

I’m using a basic transformer NMT model for automatic grammatical error correction of Arabic.

My code and model architecture are the same as in the article.

The only things I have changed are using the 'bert-base-multilingual-cased' BERT tokenizer and loading the first dataset.

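from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
# Legacy torchtext API (torchtext.legacy.data in torchtext 0.9 to 0.11)
from torchtext.data import Field, TabularDataset, BucketIterator

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')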

SRC = Field(tokenize = tokenizer.tokenize, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            batch_first = True, fix_length = 100)

TRG = Field(tokenize = tokenizer.tokenize, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True, 
            batch_first = True, fix_length = 100)

train_data, valid_data = TabularDataset.splits(path='./data/', train='fluent_train_8.csv',
    validation='fluent_vali_8.csv', format='csv', fields=[('src', SRC), ('trg', TRG)], skip_header=True)
# The first dataset has no test split, so only two iterators are built here.
# The validation iterator is named valid_iterator to match the training loop.
train_iterator, valid_iterator = BucketIterator.splits((train_data, valid_data), sort_key=lambda x: len(x.src),
     batch_size = BATCH_SIZE, device = device)

After 20 epochs with the code above, everything worked as expected.

  1. Loaded the second dataset:
train_data, valid_data, test_data = TabularDataset.splits(path='./data/', train='train.csv',
                                    validation='devlop.csv', test='test.csv', format='csv',
                                    fields=[('src', SRC), ('trg', TRG)], skip_header=True)
# The validation iterator is named valid_iterator to match the training loop.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train_data, valid_data, test_data), sort_key=lambda x: len(x.src),
     batch_size = BATCH_SIZE, device = device)
  2. Loaded the checkpoint:
checkpoint = torch.load('model.pt')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])           
epoch = checkpoint['epoch'] 
# The checkpoint saved by the training loop stores 'valid_loss' only (no training loss).
chk_valid_loss = checkpoint['valid_loss']

Unfortunately, dataset number one is huge, but the executed code is here.

Could you try to create a small dummy dataset with just a few samples?
If the loading of the checkpoint is broken, the loss jump should also be visible using a random dataset as long as the model trains.
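
A check along these lines can be sketched as follows: train a tiny model on random data, save a checkpoint with the same keys as in this thread, restore it into a fresh model and optimizer, and compare the validation loss before and after. The model, the data, and the `make_model` and `evaluate` helpers are hypothetical stand-ins, not the code from this thread.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model():
    # Hypothetical tiny model standing in for the real transformer.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))


def evaluate(model, x, y, criterion):
    """Return the loss in eval mode, without tracking gradients."""
    model.eval()
    with torch.no_grad():
        return criterion(model(x), y).item()


# Random dummy data instead of the real datasets.
x_train, y_train = torch.randn(64, 8), torch.randn(64, 1)
x_valid, y_valid = torch.randn(32, 8), torch.randn(32, 1)
criterion = nn.MSELoss()

model = make_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train briefly so that the weights and the optimizer state are non-trivial.
for _ in range(20):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

loss_before = evaluate(model, x_valid, y_valid, criterion)

# Save and reload using the same checkpoint keys as in the thread.
path = os.path.join(tempfile.mkdtemp(), 'model.pt')
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'valid_loss': loss_before}, path)

restored = make_model()
restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-2)
checkpoint = torch.load(path)
restored.load_state_dict(checkpoint['state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer'])

loss_after = evaluate(restored, x_valid, y_valid, criterion)
print(f'before: {loss_before:.6f} | after: {loss_after:.6f}')
```

If the two losses match on the same validation data, saving and loading work, and a jump seen after switching datasets has to come from somewhere else.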


I applied the same technique with a single dataset, and everything worked well.

With 20% of each of the two datasets, the jumps are less frequent, but the results are still unstable.

With the full datasets, the validation loss jumps up in every epoch.

I have concluded that this issue is related to how the model deals with the two datasets.

I’m still confused: why does this happen, and how can I fix it?

As a check, set the model to train mode in the validation script (net.train() instead of net.eval()). If the loss does NOT go up, then the problem is most likely BatchNorm. This happens more often than anyone would think.
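
A minimal sketch of this check, assuming a toy model with a BatchNorm layer and random data (the real model in this thread may not contain BatchNorm at all), and a hypothetical `validation_loss` helper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model with BatchNorm, standing in for the real network.
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
criterion = nn.MSELoss()


def validation_loss(model, x, y, criterion, train_mode=False):
    """Compute the loss without gradients, in train or eval mode, then restore the mode."""
    was_training = model.training
    model.train(train_mode)
    with torch.no_grad():
        loss = criterion(model(x), y).item()
    model.train(was_training)
    return loss


# Random validation batch whose statistics differ from the BatchNorm running statistics.
x_valid = 5.0 + 3.0 * torch.randn(32, 8)
y_valid = torch.randn(32, 1)

# Eval mode first: BatchNorm normalises with its stored running statistics.
loss_eval = validation_loss(model, x_valid, y_valid, criterion, train_mode=False)
# Train mode: BatchNorm normalises with the statistics of the current batch.
# This forward pass also updates the running statistics as a side effect.
loss_train = validation_loss(model, x_valid, y_valid, criterion, train_mode=True)

print(f'eval mode: {loss_eval:.4f} | train mode: {loss_train:.4f}')
```

Train mode also enables dropout, so small differences between the two numbers are expected; a large gap points to mismatched normalisation statistics.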


Thank you, sir. This issue is most likely related to differences between the two datasets.