Handling signals in a distributed training loop

I am going to train a model with several combinations of hyperparameters (a hyperparameter search) and would like to terminate training runs that I consider unsuccessful, to save time.

I am considering sending a Unix signal to do this. However, it seems that torchrun intercepts such signals and converts them to SIGTERM.

My code looks like the following:

import os
import signal

import torch

signal_received = False


def usr_signal_handler(signum, frame):
    global signal_received
    signal_received = True
    print('USR1 signal received')


signal.signal(signal.SIGUSR1, usr_signal_handler)

torch.distributed.init_process_group()

for _ in range(10):
    signal_received = False  # reset the flag before the next hyperparameter combination

    if int(os.environ['RANK']) == 0:
        p1, p2, p3 = generate_hyperparameters()
        torch.distributed.broadcast_object_list([p1, p2, p3], src=0)
    else:
        v = [None, None, None]
        torch.distributed.broadcast_object_list(v, src=0)
        p1, p2, p3 = v

    model = get_model(p1, p2, p3)

    for i in range(n_epochs):
        if signal_received:
            break
        dataloader = get_dataloader()
        for inputs, labels in dataloader:
            if signal_received:
                break
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            ...  # backward pass, optimizer step, etc.

torch.distributed.destroy_process_group()

Then I launch it with torchrun:

$ torchrun --nproc-per-node 2 --nnodes 1 --node-rank 0 --master-addr 192.168.1.123 --master-port 37176 train.py --other --parameters

Then I look at the output of ps xf, find the PIDs of the processes launched by torchrun, and send the USR1 signal to the first one (which I assume was assigned rank 0):

$ kill -USR1 1048139

Result: all processes are killed with SIGTERM.

W0220 11:42:40.321000 1048055 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1048139 closing signal SIGTERM
W0220 11:42:40.324000 1048055 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1048140 closing signal SIGTERM
E0220 11:42:43.274000 1048055 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: -10) local_rank: 0 (pid: 1048139) of binary: /home/user/.envs/env/bin/python3.11
Traceback (most recent call last):
  File "/home/user/.envs/env/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/user/.envs/env/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/user/.envs/env/lib/python3.11/site-packages/torch/distributed/run.py", line 918, in main
    run(args)
  File "/home/user/.envs/env/lib/python3.11/site-packages/torch/distributed/run.py", line 909, in run
    elastic_launch(
  File "/home/user/.envs/env/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jovyan/.envs/env/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
=========================================================
train.py FAILED
---------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
---------------------------------------------------------
Root Cause (first observed failure):
[0]:
time      : 2025-02-20_11:42:40
host      : hostname
rank      : 0 (local_rank: 0)
exitcode  : -10 (pid: 1048139)
error_file: <N/A>
traceback : Signal 10 (SIGUSR1) received by PID 1048139
=========================================================

Is there any way to avoid this?
Studying the sources of the torch.distributed elastic API revealed that it installs handlers for SIGTERM, SIGINT, SIGHUP and SIGQUIT (which is why I decided to use SIGUSR1 in the first place).
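
For what it's worth, a quick diagnostic I could add to train.py right after registering the handler (just a sketch using the standard os and signal modules, nothing torch-specific) would print each worker's PID and rank together with whichever handler is currently installed for SIGUSR1:

import os
import signal

# Print this worker's PID and rank together with the handler currently
# installed for SIGUSR1, to verify that my handler is still in place
# inside the worker process.
print(
    f"pid={os.getpid()} rank={os.environ.get('RANK')} "
    f"SIGUSR1 handler={signal.getsignal(signal.SIGUSR1)}"
)

That would also make it easier to pick the right PID for kill instead of guessing from the ps xf output.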

I should add that I have used exactly this approach in single-GPU training, and it worked perfectly.
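
For comparison, the working single-GPU version is essentially the following (a simplified sketch; get_model, get_dataloader, loss_func, n_epochs and the hyperparameters p1, p2, p3 are placeholders for my actual code, as above):

import signal

signal_received = False


def usr_signal_handler(signum, frame):
    global signal_received
    signal_received = True
    print('USR1 signal received')


signal.signal(signal.SIGUSR1, usr_signal_handler)

model = get_model(p1, p2, p3)

for i in range(n_epochs):
    if signal_received:
        break
    dataloader = get_dataloader()
    for inputs, labels in dataloader:
        if signal_received:
            break
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        ...  # backward pass and optimizer step

Sending kill -USR1 to that single process stops the training loop cleanly, exactly as I expect.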