I am fine-tuning the Hugging Face implementation of BERT on GLUE tasks. I ran two experiments. In the first, I fine-tune the model for 3 epochs and then evaluate. In the second, I implemented early stopping: I evaluate on the validation set at the end of each epoch to decide whether to stop training. I print the training loss every 500 steps. During the first epoch, the training losses from the two experiments are exactly the same as training proceeds. Right after the first evaluation, the model in the second experiment starts to perform worse, resulting in a higher training loss and worse validation performance. Does this have anything to do with Batch Norm? Is there a way to implement early stopping without harming performance?
The train and evaluate functions are here:
https://github.com/huggingface/transformers/blob/master/examples/run_glue.py
In the second experiment, I added the following code at the end of each training epoch:
if args.local_rank in [-1, 0] and args.early_stopping > 0:
    logs = {}
    if args.local_rank == -1:  # Only evaluate when single GPU otherwise metrics may not average well
        results = evaluate(args, model, tokenizer)
        for key, value in results.items():
            eval_key = "eval_{}".format(key)
            logs[eval_key] = value

    if "f1" in results:
        metric = "f1"
    elif "acc" in results:
        metric = "acc"
    elif "mcc" in results:
        metric = "mcc"
    else:
        metric = "corr"

    if results[metric] > best_eval_result:
        best_eval_result = results[metric]
        num_degrading_epoch = 0

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
        torch.save(optimizer.state_dict(), os.path.join(args.output_dir, "optimizer.pt"))
        torch.save(scheduler.state_dict(), os.path.join(args.output_dir, "scheduler.pt"))
    else:
        num_degrading_epoch += 1
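For context, the snippet above only updates the `num_degrading_epoch` counter; the actual stopping check (comparing it against `args.early_stopping`) happens elsewhere in my training loop. The overall pattern I am following is a standard patience-based early-stopping loop, roughly like this minimal, self-contained sketch (generic names; `train_one_epoch` and `evaluate_fn` here are placeholders standing in for run_glue.py's train and evaluate logic, not functions from the script itself):

def train_with_early_stopping(num_epochs, patience, train_one_epoch, evaluate_fn):
    best_metric = float("-inf")   # best validation score seen so far
    degrading_epochs = 0          # consecutive epochs without improvement

    for epoch in range(num_epochs):
        train_one_epoch()          # one pass over the training data
        metric = evaluate_fn()     # validation score (higher is better)

        if metric > best_metric:
            best_metric = metric
            degrading_epochs = 0   # reset the patience counter
            # (this is where the checkpoint would be saved)
        else:
            degrading_epochs += 1
            if degrading_epochs >= patience:
                break              # stop: no improvement for `patience` epochs

    return best_metric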