tf.train.Checkpoint and tf.train.CheckpointManager are incredibly useful for training and saving models.
The API works like this:
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer, step=step)
manager = tf.train.CheckpointManager(ckpt, 'ckpt', max_to_keep=3, keep_checkpoint_every_n_hours=1)

# train loop
for x, y in dataset:
    ...
    manager.save()
Is there anything, or any library, with similar functionality in PyTorch? I understand how to save things using a state_dict, but I would like to know if there is a method as simple as TF's Checkpoint.
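For reference, the manual state_dict approach I mean is roughly the following (a minimal sketch; the model, optimizer, step value, and file name are just placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
step = 10

# Bundle model and optimizer state plus the step counter into one file,
# analogous to what tf.train.Checkpoint tracks.
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    },
    "ckpt.pt",
)

# Restoring mirrors tf.train.Checkpoint.restore:
state = torch.load("ckpt.pt")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
step = state["step"]
```

This works, but keeping only the last N checkpoints, time-based retention, and so on all have to be written by hand, which is what CheckpointManager handles for you in TF.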