Saving and restoring model weights at the mini-batch level?

If a classification model is being trained on a large amount of data (~3,000,000 input images per epoch), what is the recommended approach for implementing checkpoint-like functionality at the mini-batch level instead of the epoch level (as shown here)? Can anyone recommend a way to save the weights and gradients after every x mini-batches (instead of every x epochs)? Any code snippets or MCVEs would be greatly appreciated. I am not running out of GPU memory. However, even before profiling the code, a single epoch takes hours to train (partly because of the size of the training data). That is my motivation for saving and restoring the training process at the mini-batch level, before an entire epoch has completed.

I have read the documentation for checkpoints, which seems to suggest that this is possible at the mini-batch level. I have also read the discussion on torch.utils.checkpoint.checkpoint, which seems to suggest that this is not the proper use of a checkpoint. Recommendations for the appropriate methodology (if available) are greatly appreciated! I have experimented a little with forward and backward hooks and can see how they might be used to accomplish what I want, but I am wondering if there is a better alternative.

Can I just deepcopy model.state_dict(), epoch, mini_batch_number, mini_batch_size, and optimizer.state_dict() after every x mini-batches, and then save and restore the checkpoint as normal? Or is there anything else to watch out for?

I am relatively new to PyTorch and certainly new to the forums, so please let me know if this is the wrong place to ask this question, or if this post can be improved. Constructive criticism is always appreciated!

If what you want is to save the current state of your model/optimizer so that you can restart from there if needed, torch.utils.checkpoint is not what you want, and neither are hooks.
You simply want to take the same code that saves a checkpoint every x epochs and put it in the inner loop that iterates over mini-batches, so that it runs every x mini-batches.
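A minimal sketch of that suggestion, assuming a toy `nn.Linear` model and a random `TensorDataset` in place of the real classifier and image data; `save_every`, `ckpt_path`, and the `'batch_idx'` key are hypothetical names chosen for illustration:

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real model and the ~3M-image dataset (assumptions).
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))
loader = DataLoader(dataset, batch_size=10, shuffle=True)  # 10 batches per epoch
model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

save_every = 4  # hypothetical: checkpoint every 4 mini-batches
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth.tar")
saved_at = []  # (epoch, batches done) for each save, for inspection

for epoch in range(2):
    for batch_idx, (images, targets) in enumerate(loader):
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

        # Same logic as the per-epoch save, just moved into the inner loop.
        if (batch_idx + 1) % save_every == 0:
            torch.save({
                "epoch": epoch,
                "batch_idx": batch_idx + 1,  # batches already done in this epoch
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }, ckpt_path)
            saved_at.append((epoch, batch_idx + 1))
```

Because `torch.save` writes the tensors to disk immediately, no `deepcopy` is needed in this pattern. A `deepcopy` only matters if you keep a snapshot in memory: `model.state_dict()` and `optimizer.state_dict()` return tensors that share storage with the live training state, so later `optimizer.step()` calls would change an uncopied snapshot.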

Thank you for your quick response. So is the ImageNet example on GitHub misusing checkpoints? See:

What I am really trying to ask is, can I do what is shown here:

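import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Serialize the whole dict (counters, weights, optimizer state) to disk
    torch.save(state, filename)
    if is_best:
        # Keep a separate copy of the best model seen so far
        shutil.copyfile(filename, 'model_best.pth.tar')


# Inside the training loop, once per epoch:
save_checkpoint({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
    'optimizer' : optimizer.state_dict(),
}, is_best)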

and resuming from the checkpoint:

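import os
import torch

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        # Restore the weights and the optimizer state (e.g. momentum buffers)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))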

But with the mini-batch in addition to the epoch? If this isn’t the correct use of a checkpoint, what method of serialization should I use to save and restore the state (gradients, weights, epoch_num, minibatch_num) during training?

This is the correct way to save and restore your model’s parameters.
The name checkpoint can be a bit misleading, as it can refer either to a saved state of your model's parameters or to torch.utils.checkpoint, which is used to trade compute for memory (it recomputes intermediate activations during the backward pass instead of storing them). That is what @albanD meant.
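To make the distinction concrete, here is a small sketch of what `torch.utils.checkpoint.checkpoint` actually does, assuming a recent PyTorch release where the `use_reentrant` keyword is available. It wraps part of the forward pass so that the activations inside it are recomputed during backward. The outputs and gradients match a normal pass, and nothing is written to disk:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Regular forward/backward: activations inside `block` are kept for backward.
out_plain = block(x)
out_plain.sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed: activations inside `block` are dropped after the forward pass
# and recomputed during backward, trading extra compute for less memory.
out_ckpt = checkpoint(block, x, use_reentrant=False)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()
```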

Ok, great. Thank you for clarifying! I didn't realize torch.utils.checkpoint was used for trading compute for memory. The name is a bit of a misnomer, then. I think the first meaning is consistent with other machine learning frameworks (at least with TensorFlow).

So is there a way to save and restore the mini-batch index along with the epoch index? I doubt it is as simple as:

save_checkpoint({
    'epoch': epoch + 1,
    'mini_batch': minibatch + 1,
}, is_best)

Is there functionality in the model or the optimizer to load a minibatch size and index alongside the epoch index?

Neither the model nor the optimizer use the epoch or batch count.
It’s up to you to store and load the model checkpoints.
The model and optimizer just use .load_state_dict() to restore their state.
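A sketch of that round trip, using an in-memory `io.BytesIO` buffer in place of a file; the `'epoch'`/`'batch_idx'` keys and their values are arbitrary choices for illustration. It shows that the counters live only in the dictionary you build, while `load_state_dict()` restores the weights and the optimizer's momentum buffers:

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has a momentum buffer to save.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# The counters live only in this dict; neither state_dict contains them.
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save({"epoch": 3, "batch_idx": 1500,  # arbitrary example values
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Restore into freshly constructed objects, as after a restart.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
ckpt = torch.load(buffer)
new_model.load_state_dict(ckpt["state_dict"])
new_optimizer.load_state_dict(ckpt["optimizer"])
start_epoch, start_batch = ckpt["epoch"], ckpt["batch_idx"]
```

Gradients (`param.grad`) are not part of either state dict. If the checkpoint is taken right after `optimizer.step()` and gradients are zeroed at the start of each iteration, they are not needed to resume.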

If working with epoch + mini-batch index suits you, that's fine. Alternatively, you could store the epoch as a fraction, e.g. after half of the mini-batches in the first epoch, just save 'epoch': 0.5.
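The fractional-epoch idea amounts to simple arithmetic. A sketch with two hypothetical helper functions, assuming the number of mini-batches per epoch is known (e.g. `len(loader)`) and stays the same between saving and restoring:

```python
def to_fractional_epoch(epoch, batches_done, batches_per_epoch):
    """Encode (epoch index, batches completed in that epoch) as one float."""
    return epoch + batches_done / batches_per_epoch


def from_fractional_epoch(value, batches_per_epoch):
    """Recover (epoch index, batches completed) from the float."""
    epoch = int(value)
    # round() absorbs the small floating-point error from the division
    batches_done = round((value - epoch) * batches_per_epoch)
    return epoch, batches_done
```

For example, with 10,000 mini-batches per epoch, `to_fractional_epoch(0, 5000, 10000)` gives `0.5`. Storing the two integers separately, as in the `'mini_batch'` key above, avoids the rounding step entirely.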

Good to know, thank you @ptrblck and @albanD for addressing my misconceptions. For others searching for this, my takeaway from the discussion is as follows (please correct me if I'm wrong):

  • There is currently no dedicated module in PyTorch for saving and resuming training. The building blocks are there, but it is up to the user to implement how training is resumed.
  • PyTorch does provide a way to save and restore a model's parameters through state_dict() and load_state_dict().
    • The same approach works for the optimizer's state (e.g. momentum buffers) and hyperparameters via optimizer.state_dict() and optimizer.load_state_dict(). Gradients are not stored in either state dict.
  • Although serialization functions (torch.save / torch.load) exist, saving and resuming progress during training requires additional user-written logic on top of them, e.g. storing and restoring the epoch and mini-batch counters.
  • Although there is a torch.utils.checkpoint, it is NOT intended for the same use case as TensorFlow's similarly named checkpoints; instead, it trades compute for memory by recomputing intermediate activations during the backward pass.
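As a sketch of that user-side logic, the following resumes in the middle of an epoch. It assumes a toy model and dataset, a hypothetical fixed `BASE_SEED` so that each epoch's shuffle order can be regenerated after a restart, and hypothetical helpers `epoch_order`, `make_loader`, and `train`. The sampler is sliced so that the mini-batches already processed are never loaded again:

```python
import os

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 10
BASE_SEED = 1234  # hypothetical: fixed seed so each epoch's shuffle can be replayed


def epoch_order(num_samples, epoch):
    """Deterministic shuffle order of sample indices for a given epoch."""
    g = torch.Generator().manual_seed(BASE_SEED + epoch)
    return torch.randperm(num_samples, generator=g).tolist()


def make_loader(dataset, epoch, start_batch=0):
    """Loader over this epoch's order, minus the batches already processed."""
    order = epoch_order(len(dataset), epoch)[start_batch * BATCH_SIZE:]
    # A plain list of indices is accepted as a sampler; skipped samples are never loaded.
    return DataLoader(dataset, batch_size=BATCH_SIZE, sampler=order)


def train(num_epochs, ckpt_path, stop_after=None, resume=False):
    torch.manual_seed(0)
    # Toy stand-ins for the real images and classifier (assumptions).
    dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 3, (100,)))
    model = nn.Linear(8, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    start_epoch, start_batch = 0, 0
    if resume and os.path.isfile(ckpt_path):
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_epoch, start_batch = ckpt["epoch"], ckpt["batch_idx"]

    steps = 0
    for epoch in range(start_epoch, num_epochs):
        first = start_batch if epoch == start_epoch else 0
        loader = make_loader(dataset, epoch, first)
        for batch_idx, (x, y) in enumerate(loader, start=first):
            optimizer.zero_grad()
            criterion(model(x), y).backward()
            optimizer.step()
            steps += 1
            # Saved after every batch for brevity; use a larger interval in practice.
            torch.save({"epoch": epoch, "batch_idx": batch_idx + 1,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict()}, ckpt_path)
            if stop_after is not None and steps == stop_after:
                return model  # simulates the job being killed mid-epoch
    return model
```

Running `train` uninterrupted, or stopping it with `stop_after=13` and then calling it again with `resume=True`, should give the same final weights. If the model uses dropout or random augmentation, the RNG state (`torch.get_rng_state()`) and any learning-rate scheduler's `state_dict()` would also need to go into the checkpoint.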