Equivalent or analogous PyTorch version of `tf.train.Checkpoint`

`tf.train.Checkpoint` and `tf.train.CheckpointManager` are incredibly useful for training and saving models.

The API works like this:

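```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer, step=step)
manager = tf.train.CheckpointManager(ckpt, 'ckpt', max_to_keep=3, keep_checkpoint_every_n_hours=1)

# train loop
for x, y in dataset:
    train_step(x, y)         # hypothetical helper: one optimization step on the batch
    ckpt.step.assign_add(1)  # assumes `step` was created as a tf.Variable
    manager.save()           # writes a new checkpoint; older ones are pruned per max_to_keep
```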

Is there anything, or any library, with similar functionality in PyTorch? I know how to save objects via their `state_dict`, but would like to know if there's a method that's as simple as TF's `Checkpoint`.

If you don't want to store the `state_dict`s manually, I would recommend checking out a higher-level API such as Ignite or Lightning, both of which provide their own checkpointing utilities.

CC @vfdev-5 and @williamFalcon for more information.
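For comparison, the manual `state_dict` route can be wrapped in a small helper that behaves roughly like `tf.train.CheckpointManager`. This is a hypothetical sketch, not a PyTorch API: the class name `CheckpointManager`, its `save`/`restore`/`latest_checkpoint` methods, the `keep_every_n_hours` parameter and the `ckpt-<step>.pt` file naming are all made up for illustration. Only `torch.save`, `torch.load` and the `state_dict()`/`load_state_dict()` methods are real PyTorch calls.

```python
import os
import tempfile
import time

import torch


class CheckpointManager:
    """Hypothetical, minimal analogue of tf.train.CheckpointManager.

    Tracks any objects that expose state_dict()/load_state_dict() (modules,
    optimizers, LR schedulers), keeps the `max_to_keep` most recent files and,
    optionally, preserves one older file every `keep_every_n_hours` hours.
    """

    def __init__(self, directory, max_to_keep=3, keep_every_n_hours=None, **objects):
        self.directory = directory
        self.max_to_keep = max_to_keep
        self.keep_every_n_hours = keep_every_n_hours
        self.objects = objects
        self._pending = []  # (path, save_time) of files this manager may delete, oldest first
        self._last_preserved = time.time()
        os.makedirs(directory, exist_ok=True)

    def save(self, step):
        """Write every tracked state_dict plus the step counter, then prune old files."""
        state = {
            "step": step,
            "objects": {name: obj.state_dict() for name, obj in self.objects.items()},
        }
        path = os.path.join(self.directory, f"ckpt-{step}.pt")
        torch.save(state, path)
        self._pending.append((path, time.time()))
        while len(self._pending) > self.max_to_keep:
            old_path, saved_at = self._pending.pop(0)
            hours = (saved_at - self._last_preserved) / 3600
            if self.keep_every_n_hours is not None and hours >= self.keep_every_n_hours:
                self._last_preserved = saved_at  # keep this file permanently
            else:
                os.remove(old_path)
        return path

    def latest_checkpoint(self):
        """Path of the highest-numbered ckpt-<step>.pt in the directory, or None."""
        steps = [
            int(name[len("ckpt-"):-len(".pt")])
            for name in os.listdir(self.directory)
            if name.startswith("ckpt-") and name.endswith(".pt")
        ]
        return os.path.join(self.directory, f"ckpt-{max(steps)}.pt") if steps else None

    def restore(self, path=None):
        """Load a checkpoint (default: the latest) into the tracked objects; return its step."""
        path = path or self.latest_checkpoint()
        if path is None:
            return 0  # nothing saved yet: start from scratch
        state = torch.load(path, map_location="cpu")
        for name, obj in self.objects.items():
            obj.load_state_dict(state["objects"][name])
        return state["step"]


# Usage with a toy model; a throwaway directory stands in for 'ckpt'.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ckpt_dir = tempfile.mkdtemp()
manager = CheckpointManager(ckpt_dir, max_to_keep=3, model=model, optimizer=optimizer)

start = manager.restore()  # 0 here; after a restart it resumes from the latest file
for step in range(start + 1, start + 6):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    manager.save(step)  # after the loop only ckpt-3.pt, ckpt-4.pt, ckpt-5.pt remain
```

One simplification compared to TF: a freshly created manager only prunes files it wrote itself, so checkpoints left over from an earlier run are never deleted automatically.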

@Nathan_Wood Concerning Ignite, we have a `Checkpoint` class that can be used to store 1) training checkpoints containing the model, optimizer, LR scheduler, etc., and 2) the best N models.

A complete usage example can be found here.


PS: Please take a look, and feel free to leave us feedback :slight_smile: