Naming .pth checkpoints using PyTorch Ignite ModelCheckpoint handler

I use the ModelCheckpoint handler in Ignite to save the model with these lines:

ckpt_handler = ModelCheckpoint(f'./experiments/{log_num}/checkpoints', filename_prefix='epoch_', n_saved=None, create_dir=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), ckpt_handler, to_save)

but the .pth filenames follow this pattern: epoch__checkpoint_344.pth (344 refers to the iteration count).

I would like to include the epoch number in the filename… Any idea how that's done?

@xen0f0n it can be done by adding global_step_transform=lambda e, _: e.state.epoch:

checkpointer = ModelCheckpoint(*args, **kwargs, global_step_transform=lambda e, _: e.state.epoch)

Thank you @vfdev-5! Not sure how this works, though… I tried accessing trainer.state.epoch (before trainer.run) and state was None. Is this what that lambda does, i.e. access the trainer engine's state after trainer.run?

Yes, the state is None at the beginning of training because it has not been initialized yet.
When you attach

trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), ckpt_handler, to_save)

ckpt_handler is triggered once every 2 epochs to save the objects in to_save. Its global_step_transform argument is an optional callable, documented as:

global step transform function to output a desired global step. Input of the function is (engine, event_name). Output of function should be an integer. Default is None, global_step based on attached engine. If provided, uses function output as global_step.

So the lambda is only executed during training, at a point where the trainer's state has been initialized and can be used.
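To make the mechanics concrete, here is a pure-Python sketch of the naming logic described above. checkpoint_suffix and the fake engine object are illustrative stand-ins, not Ignite's actual internals: by default the filename suffix comes from the attached engine's iteration count (hence checkpoint_344), while a global_step_transform overrides it with whatever integer the callable returns.

```python
from types import SimpleNamespace

def checkpoint_suffix(engine, event_name, global_step_transform=None):
    # Mimics (in simplified form) how the checkpoint handler picks the
    # number appended to the filename.
    if global_step_transform is None:
        # Default: use the attached engine's iteration count.
        return engine.state.iteration
    # Otherwise: use the callable's output as the global step.
    return global_step_transform(engine, event_name)

# A fake trainer state after 344 iterations, i.e. epoch 4.
engine = SimpleNamespace(state=SimpleNamespace(iteration=344, epoch=4))

# Default behaviour: suffix is the iteration count.
print(f"epoch__checkpoint_{checkpoint_suffix(engine, 'EPOCH_COMPLETED')}.pth")
# prints epoch__checkpoint_344.pth

# With the suggested lambda: suffix is the epoch number.
gst = lambda e, _: e.state.epoch
print(f"epoch__checkpoint_{checkpoint_suffix(engine, 'EPOCH_COMPLETED', gst)}.pth")
# prints epoch__checkpoint_4.pth
```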

Please see the docs for more info: https://pytorch.org/ignite/handlers.html#ignite.handlers.Checkpoint

Hope this helps 🙂