Having started playing around with PyTorch, I thought of a neat workflow for experimenting with learning rates.
I’d set up a training loop, run it inside a Jupyter notebook, and watch the loss plot update in real time (using IPython.display.clear_output and display to redraw the plot). Once the loss stops going down, I’d press I twice to interrupt the kernel, go back to the cell that defines the optimizer, create a new one with a different learning rate, and re-run the training loop on the same model, continuing where it left off.
Since the loss values are stored in a separate list, the plot just keeps going, and the model instance is kept around too, since only the optimizer changes. There’s no need for checkpoint saving/loading, LR decay schedules, or any other programmer overhead.
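For concreteness, here’s a minimal sketch of the workflow I mean. The model, the problem, and all names (model, losses, train, etc.) are just placeholders I made up for illustration; in the notebook, model and losses live in their own cells so they survive a kernel interrupt, and only the optimizer is rebuilt between runs:

```python
# Sketch of the workflow: the loss list and model persist across
# re-runs of the training cell; only the optimizer is recreated.
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy regression problem standing in for a real dataset.
X = torch.randn(256, 10)
y = X @ torch.randn(10, 1) + 0.1 * torch.randn(256, 1)

model = nn.Linear(10, 1)   # defined once, in its own cell
losses = []                # persistent, so the plot keeps growing

def train(model, opt, steps=100):
    loss_fn = nn.MSELoss()
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
        # In the notebook you'd redraw here:
        # clear_output(wait=True); plt.plot(losses); display(plt.gcf())

# Run with one LR, "interrupt", rebuild the optimizer, continue:
train(model, torch.optim.SGD(model.parameters(), lr=0.1), steps=50)
train(model, torch.optim.SGD(model.parameters(), lr=0.01), steps=50)
```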
Now the question is: is it safe to interrupt training (.backward() and .step()) at arbitrary points in time with a KeyboardInterrupt exception, or am I risking a possibly corrupt state somewhere deep down while the model is being updated?
If so, what would you guys recommend to make this safe? The only thing I’ve thought of so far is to wrap each minibatch in a try/except, so that the interrupt doesn’t land in the middle of a minibatch but still stops training cleanly. Do you guys do this, or have any other tips for interactively playing with a network during training?
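One variant of that idea I’ve been considering (all names here are my own, and this is only a sketch): instead of catching KeyboardInterrupt directly, install a SIGINT handler that just sets a flag, and check the flag only between batches. That way .backward() and .step() are never interrupted mid-update:

```python
# Sketch: defer the kernel interrupt to a batch boundary by trapping
# SIGINT in a flag and honoring it only after opt.step() completes.
import signal
import torch
import torch.nn as nn

stop_requested = False

def _request_stop(signum, frame):
    # Only record the request; the training loop decides when to stop.
    global stop_requested
    stop_requested = True

def train_safely(model, opt, X, y, steps=100):
    global stop_requested
    stop_requested = False
    prev = signal.signal(signal.SIGINT, _request_stop)
    loss_fn = nn.MSELoss()
    done = 0
    try:
        for step in range(steps):
            opt.zero_grad()
            loss_fn(model(X), y).backward()
            opt.step()          # the update itself is never interrupted
            done = step + 1
            if stop_requested:  # honored only at a batch boundary
                print(f"stopping cleanly after step {done}")
                break
    finally:
        signal.signal(signal.SIGINT, prev)  # restore normal Ctrl-C behavior
    return done

model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.05)
X, y = torch.randn(32, 3), torch.randn(32, 1)
steps_done = train_safely(model, opt, X, y, steps=20)
```

One caveat: signal.signal only works from the main thread, which is fine for the usual notebook setup.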
Coming from TensorFlow, an interactive workflow like this feels like a breath of fresh air.