Hi. In a standard training script, I have the following partial function:
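from functools import partial
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

# Shared checkpoint settings; score_name and score_function are supplied later.
best_model_handler = partial(
    Checkpoint,
    {"model": model},
    DiskSaver(dirname=cfg.work_dir.as_posix(), require_empty=False),
    filename_prefix="best",
    n_saved=2,
    global_step_transform=global_step_from_engine(trainer),
)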
If I want to evaluate, I do this:
best_model_handler = best_model_handler(
    score_name="val_bleu",
    score_function=Checkpoint.get_default_score_fn("bleu"),
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
This works fine. On the other hand, when I train without validation, I do the following:
best_model_handler = best_model_handler(
    score_name="train_loss",
    score_function=Checkpoint.get_default_score_fn("loss"),
)
trainer.add_event_handler(Events.COMPLETED, best_model_handler)
But it throws the following error:
  File "/home/admin/anaconda3/envs/catbird/lib/python3.8/site-packages/ignite/handlers/checkpoint.py", line 648, in wrapper
    return score_sign * engine.state.metrics[metric_name]
KeyError: 'loss'
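Going by the line shown in the traceback, the default score function seems to simply look up the given name in engine.state.metrics. Here is a minimal sketch of that behavior in plain Python. It does not import Ignite: make_score_fn, evaluator_like, and trainer_like are hypothetical stand-ins, and it assumes the trainer's metrics dict is empty because no metric named "loss" was attached to it.

```python
from types import SimpleNamespace


def make_score_fn(metric_name, score_sign=1.0):
    """Hypothetical stand-in that mirrors the line shown in the traceback."""

    def wrapper(engine):
        # Same lookup as in the traceback: fails if the key is missing.
        return score_sign * engine.state.metrics[metric_name]

    return wrapper


# Stand-in for the evaluator: a metric named "bleu" is present.
evaluator_like = SimpleNamespace(state=SimpleNamespace(metrics={"bleu": 0.31}))

# Stand-in for the trainer: assumed to have no metrics attached.
trainer_like = SimpleNamespace(state=SimpleNamespace(metrics={}))

# The lookup succeeds when the metric exists.
bleu_score = make_score_fn("bleu")(evaluator_like)

# The lookup raises KeyError when the metric is absent.
try:
    make_score_fn("loss")(trainer_like)
    error = None
except KeyError as exc:
    error = exc
```

In this sketch the evaluator-like object works because its metrics dict contains "bleu", while the trainer-like object fails in the same way as my script.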
However, if I do not define score_name or score_function, the script runs fine, but the model is not saved. How can I fix this? Thanks in advance for any help you can provide.