Signal handling and torch.save()

Hi,

I am thinking of registering a signal handler via signal.signal() to catch a certain range of signals, such as SIGINT, that would eventually lead to termination of the current process. The idea is that this handler would save a checkpoint of the current state of the model when such a signal arrives. My concern is that this could happen, say, halfway through backpropagation, leaving part of the weights and biases not yet updated, i.e. in an inconsistent state.

So I wonder whether that is a sensible strategy at all. What can I expect to happen in case of such events? Is there a difference between CPU and GPU? And how would I write out checkpoints when a user presses Ctrl-C?
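
For concreteness, this is roughly what I have in mind (a minimal sketch; `model`, `optimizer`, and the checkpoint path are placeholders for my actual objects):

```python
import signal
import torch

# Sketch: save a checkpoint directly from the signal handler, then exit.
# `model` and `optimizer` stand in for the real training objects.
def save_and_exit(signum, frame):
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        "checkpoint.pt",
    )
    raise SystemExit(1)

signal.signal(signal.SIGINT, save_and_exit)   # Ctrl-C
signal.signal(signal.SIGTERM, save_and_exit)  # e.g. polite kill from a scheduler
```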

Many thanks.

Hey, I’m wondering the exact same thing right now. Did you ever find an answer, or try it and find that it always worked?

This RFC has a list of the relevant hook APIs, along with some suggested improvements: [RFC] Consolidated and unified state_dict and load_state_dict hooks · Issue #75287 · pytorch/pytorch · GitHub

Hi Mark,

Thank you for your reply! Unfortunately, I’m still not sure I understand how to handle it.
As I understand your linked post, it discusses various hooks that allow inserting computation before the state_dict is computed.

What I’m wondering is whether a signal handler for SIGTERM (registered with signal.signal()) could see the state_dict in a corrupted state (for example, when SIGTERM is delivered during the weight-update step), which would make it a bad idea to save the state_dict to disk in the signal handler before exiting the program.

For context: this is relevant, e.g., on SLURM-managed clusters, where a job typically receives SIGTERM before it is killed.
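
What I am currently leaning towards is having the handler only set a flag and deferring the actual torch.save() to an iteration boundary, where no update should be in flight. A rough, untested sketch of that pattern (`train_step`, `model`, `dataloader`, and `num_epochs` are placeholders):

```python
import signal
import torch

stop_requested = False

def request_stop(signum, frame):
    # Only record that a signal arrived; defer the save to a safe point.
    global stop_requested
    stop_requested = True

signal.signal(signal.SIGTERM, request_stop)
signal.signal(signal.SIGINT, request_stop)

# `train_step`, `model`, `dataloader`, and `num_epochs` are placeholders.
for epoch in range(num_epochs):
    for batch in dataloader:
        train_step(batch)  # forward / backward / optimizer.step()
        if stop_requested:
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # let any queued GPU work finish first
            torch.save(model.state_dict(), "checkpoint.pt")
            raise SystemExit(0)
```

My understanding is that CPython runs Python-level signal handlers only in the main thread between bytecode instructions anyway, but since CUDA kernels execute asynchronously, deferring the save to the loop and synchronizing first seems like the safer option. Does that reasoning hold?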

Thank you for your help!
Best, Max