Problem with model accuracy (after restore) on TPU

I’m training GPT-2 from huggingface/transformers on TPU. It trains well; at the end of training the loss is around 4.36. When I save and restore the model, the loss skyrockets to about 9.75.

I have no similar issues with saving and loading on GPU with this code.

The code used to save is just this:

xm.save(model_to_save.state_dict(), output_model_file)

xm.save is a convenience wrapper that moves tensors from TPU to CPU before saving.
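As far as I understand (this is a rough sketch, not the actual torch_xla implementation), it amounts to something like this:

    import torch

    # Rough sketch of what xm.save does: copy every tensor in the state_dict
    # to CPU, then fall back to the regular torch.save. On a multi-core setup
    # xm.save also makes sure only one process actually writes the file.
    def save_from_tpu(state_dict, path):
        cpu_state = {name: tensor.cpu() for name, tensor in state_dict.items()}
        torch.save(cpu_state, path)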

The whole code is here: ru_transformers/tpu_lm_finetuning.py at master · mgrankin/ru_transformers · GitHub

I’ve tried the following.

  1. I tried saving and loading right after training:
    results = evaluate(args, model, tokenizer, "checkpoint-0", False)
    log_info(f"Eval1 {results}")
    
    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
    model.to(args.device)
    results = evaluate(args, model, tokenizer, "checkpoint-0", False)
    log_info(f"Eval2 {results}")

Eval2 is much bigger than Eval1.

  2. I also tried not recreating the model, but replacing the model’s state_dict with the saved one:
    results = evaluate(args, model, tokenizer, "checkpoint-0", False)
    log_info(f"Eval1 {results}")
    
    model.load_state_dict(torch.load('output/classic_s/pytorch_model.bin'))
    model.to(args.device)
    results = evaluate(args, model, tokenizer, "checkpoint-0", False)
    log_info(f"Eval2 {results}")

In that case, Eval2 is equal to Eval1.

So there is something that isn’t in the state_dict, yet it affects the model’s performance. What could that be?
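One way I can think of to narrow this down is to compare the weights parameter by parameter between the model in memory and the one reloaded from disk. The helper below is just a sketch (the name compare_state_dicts is mine, it is not in the script):

    import torch

    # Hypothetical diagnostic helper: report which parameters differ between
    # the in-memory state_dict and the one loaded back from disk.
    def compare_state_dicts(sd_a, sd_b, atol=1e-6):
        mismatched = []
        for name, tensor_a in sd_a.items():
            if name not in sd_b:
                mismatched.append(name)
                continue
            if not torch.allclose(tensor_a.cpu().float(), sd_b[name].cpu().float(), atol=atol):
                mismatched.append(name)
        return mismatched

    # If this comes back empty, the weights survive the save/load round trip,
    # and the regression must come from something outside the state_dict.
    # mismatched = compare_state_dicts(model.state_dict(),
    #                                  torch.load('output/classic_s/pytorch_model.bin'))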

Skimming through your code, it looks like you are using AdamW as your optimizer, which keeps internal state (per-parameter moment estimates). To be able to properly resume your training, you should also store/restore the optimizer’s state_dict.
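Something along these lines (a minimal sketch; the file names and the optimizer construction are placeholders, not taken from your script):

    import torch
    import torch_xla.core.xla_model as xm
    from transformers import GPT2LMHeadModel, AdamW

    device = xm.xla_device()
    model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
    optimizer = AdamW(model.parameters(), lr=5e-5)

    # ... training loop ...

    # Save the optimizer's internal state alongside the model weights.
    xm.save(model.state_dict(), 'pytorch_model.bin')
    xm.save(optimizer.state_dict(), 'optimizer.pt')

    # To resume training later, restore both before continuing.
    model.load_state_dict(torch.load('pytorch_model.bin'))
    optimizer.load_state_dict(torch.load('optimizer.pt'))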


Thank you for the valuable advice on saving the AdamW state, but that is not the root problem here. I don’t resume training after the save/load; the model performs much worse right after loading.

Davide Libenzi is trying hard to help me with the issue here: https://github.com/pytorch/xla/issues/1245

The issue is mostly resolved here.