When a classification model is being trained on a large amount of data (~3,000,000 input images per epoch), what is the recommended approach for implementing checkpoint-like functionality at the mini-batch level, instead of the epoch level (as shown here)? Can anyone recommend a way to save the weights and gradients after every x mini-batches (instead of every x epochs)? Any code snippets or MCVEs would be greatly appreciated. I am not running out of GPU memory. However, prior to profiling the code I am seeing hours' worth of training time for a single epoch (in part due to the size of the training data), which is the motivation for saving and restoring the training process at the mini-batch level, before an entire epoch has completed.
I have read the documentation for checkpoints, which seems to suggest that this is possible at the mini-batch level. I have also read the discussion on torch.utils.checkpoint.checkpoint, which seems to suggest that this is not the proper use of that function (it trades compute for memory during backprop rather than persisting training state to disk). Recommendations for the appropriate methodology (if one exists) would be greatly appreciated! I have experimented a little with forward and backward hooks and can see how they might be used to accomplish the functionality I want, but I am wondering if there is a better alternative.
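For concreteness, something along these lines is what I am picturing: the usual torch.save pattern from the saving/loading tutorial, but triggered inside the mini-batch loop instead of once per epoch. This is only a toy sketch with made-up sizes; the model, the dummy data, and the save_every interval are all placeholders for my real setup.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-ins for the real classifier and DataLoader.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# 10 dummy mini-batches of (inputs, targets).
data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(10)]

save_every = 5  # placeholder: checkpoint after every 5 mini-batches

for epoch in range(1):
    for batch_idx, (inputs, targets) in enumerate(data):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Mini-batch-level checkpoint: record where we are inside the
        # epoch alongside the model and optimizer state.
        if (batch_idx + 1) % save_every == 0:
            torch.save({
                "epoch": epoch,
                "batch_idx": batch_idx,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss.item(),
            }, "checkpoint.pt")
```

Is saving the optimizer state_dict plus the (epoch, batch_idx) position enough here, or does something about the mini-batch granularity break this pattern?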
Can I just stop after every x mini-batches and save and restore the checkpoint as normal? Or is there anything else to watch out for?
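In particular, I am unsure about the restore side. The sketch below is my guess at how resuming mid-epoch might look: reload the state_dicts, skip the mini-batches already seen, and also save/restore the RNG state so shuffling and dropout line up after the restart. The key names ("epoch", "batch_idx", "rng_state", etc.) are my own invention, not anything prescribed by PyTorch.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Pretend a crash happened after mini-batch 4 of epoch 0: write the
# checkpoint we would have saved at that point.
torch.save({
    "epoch": 0,
    "batch_idx": 4,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "rng_state": torch.get_rng_state(),  # so randomness lines up on resume
}, "resume_checkpoint.pt")

# --- Restore side ---
ckpt = torch.load("resume_checkpoint.pt")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
torch.set_rng_state(ckpt["rng_state"])

start_epoch = ckpt["epoch"]
start_batch = ckpt["batch_idx"] + 1  # first mini-batch not yet trained on

# 10 dummy mini-batches standing in for the real DataLoader.
data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(10)]

for epoch in range(start_epoch, 1):
    for batch_idx, (inputs, targets) in enumerate(data):
        # Skip mini-batches already consumed before the checkpoint.
        if epoch == start_epoch and batch_idx < start_batch:
            continue
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
```

Skipping already-seen batches by fast-forwarding the loop like this feels wasteful with a shuffled DataLoader over 3M images, so if there is a proper way to restore the sampler/DataLoader position as well, I would love to hear it.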
I am relatively new to PyTorch and certainly new to the forums, so please let me know if this is the wrong place to ask this question, or if this post can be improved. Constructive criticism is always appreciated!