Job hangs at backward

PyTorch 1.9 on A100, training through DeepSpeed ZeRO stage 2 with fp16 (as the traces below show). This is a large project and it is not easy to extract a reproducible script. The job runs well for several thousand iterations, but then it hangs in the backward pass.
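To catch the stall automatically in future runs instead of attaching gdb by hand, a watchdog roughly like the sketch below could dump every thread's Python stack when a step takes too long. This is only a minimal sketch using the standard library; the 600-second timeout and the tiny dummy model are placeholders, not taken from my trainer.

import sys
import faulthandler
import torch

# Placeholders standing in for the real model and batch.
model = torch.nn.Linear(8, 1)
batch = torch.randn(4, 8)

# Arm the watchdog before the step: if it is not cancelled within 600 s,
# faulthandler dumps every thread's Python traceback to stderr.
faulthandler.dump_traceback_later(timeout=600, exit=False, file=sys.stderr)

loss = model(batch).sum()   # placeholder forward; the real job goes through DeepSpeed
loss.backward()             # stand-in for model_engine.backward(loss), the call that hangs

# Disarm the watchdog once the step finished in time; re-arm it next iteration.
faulthandler.cancel_dump_traceback_later()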

Here is the output of where under gdb:

#0  0x00007f9219231376 in pthread_cond_wait@@GLIBC_2.3.2 () from /usr/lib/x86_64-linux-gnu/libpthread.so.0
#1  0x00007f921509e4d1 in __gthread_cond_wait (__mutex=<optimized out>, __cond=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/gcc_compilers_1634095553113/work/build/x86_64-conda-linux-gnu/libstdc++-v3/include/x86_64-conda-linux-gnu/bits/gthr-default.h:865
#2  std::__condvar::wait (__m=..., this=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/gcc_compilers_1634095553113/work/build/x86_64-conda-linux-gnu/libstdc++-v3/include/bits/std_mutex.h:155
#3  std::condition_variable::wait (this=<optimized out>, __lock=...) at ../../../../../libstdc++-v3/src/c++11/condition_variable.cc:41
#4  0x00007f91cf9ad66b in torch::autograd::ReadyQueue::pop() () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#5  0x00007f91cf9b1cf1 in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#6  0x00007f91cf9ad3fb in torch::autograd::Engine::execute_with_graph_task(std::shared_ptr<torch::autograd::GraphTask> const&, std::shared_ptr<torch::autograd::Node>, torch::autograd::InputBuffer&&) () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#7  0x00007f91d61940fd in torch::autograd::python::PythonEngine::execute_with_graph_task(std::shared_ptr<torch::autograd::GraphTask> const&, std::shared_ptr<torch::autograd::Node>, torch::autograd::InputBuffer&&) () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#8  0x00007f91cf9af184 in torch::autograd::Engine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so
#9  0x00007f91d619404e in torch::autograd::python::PythonEngine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#10 0x00007f91d6194b76 in THPEngine_run_backward(_object*, _object*, _object*) () from /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#11 0x00005655023b2914 in cfunction_call_varargs (kwargs=<optimized out>, args=<optimized out>, func=<built-in method run_backward of torch._C._EngineBase object at remote 0x7f9209b2ff80>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Objects/call.c:743
#12 PyCFunction_Call (func=<built-in method run_backward of torch._C._EngineBase object at remote 0x7f9209b2ff80>, args=<optimized out>, kwargs=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Objects/call.c:773
#13 0x00005655023b5ebe in _PyObject_MakeTpCall (callable=<built-in method run_backward of torch._C._EngineBase object at remote 0x7f9209b2ff80>, args=<optimized out>, nargs=<optimized out>, keywords=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Objects/call.c:159
#14 0x0000565502457c85 in _PyObject_Vectorcall (kwnames=('allow_unreachable', 'accumulate_grad'), nargsf=<optimized out>, args=<optimized out>, callable=<built-in method run_backward of torch._C._EngineBase object at remote 0x7f9209b2ff80>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Include/cpython/abstract.h:125
#15 _PyObject_Vectorcall (kwnames=('allow_unreachable', 'accumulate_grad'), nargsf=<optimized out>, args=<optimized out>, callable=<built-in method run_backward of torch._C._EngineBase object at remote 0x7f9209b2ff80>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Include/cpython/abstract.h:115
#16 call_function (kwnames=('allow_unreachable', 'accumulate_grad'), oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Python/ceval.c:4963
#17 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Python/ceval.c:3515
#18 0x0000565502439433 in PyEval_EvalFrameEx (throwflag=0,
    f=Frame 0x7f91042d0440, for file /opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py, line 659, in backward (tensors=(<Tensor at remote 0x7f9104743bc0>,), grad_tensors=None, retain_graph=False, create_graph=False, grad_variables=None, inputs=(), grad_tensors_=(<Tensor at remote 0x7f9104b21340>,)))
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Python/ceval.c:741
#19 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>,
    kwargs=<optimized out>, kwcount=<optimized out>, kwstep=<optimized out>, defs=<optimized out>, defcount=<optimized out>, kwdefs=<optimized out>, closure=<optimized out>, name=<optimized out>,
    qualname=<optimized out>) at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Python/ceval.c:4298
#20 0x000056550243a818 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7f91043247a8, nargsf=<optimized out>, kwnames=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Objects/call.c:436
#21 0x0000565502453eb2 in _PyObject_Vectorcall (kwnames=('inputs',), nargsf=<optimized out>, args=<optimized out>, callable=<function at remote 0x7f915ab6ca60>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Include/cpython/abstract.h:127
#22 call_function (kwnames=('inputs',), oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=<optimized out>)
    at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Python/ceval.c:4963
#23 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /home/conda/feedstock_root/build_artifacts/python-split_1620713755437/work/Python/ceval.c:3515
#24 0x0000565502439433 in PyEval_EvalFrameEx (throwflag=0,
    f=Frame 0x7f9104324610, for file /opt/conda/lib/python3.8/site-packages/torch/_tensor.py, line 511, in backward (self=<Tensor at remote 0x7f9104743bc0>, gradient=None, retain_graph=False, creat
(remaining frames truncated at the gdb pager prompt)

Here is the output of py-bt under gdb:

(gdb) py-bt
Traceback (most recent call first):
  <built-in method run_backward of torch._C._EngineBase object at remote 0x7f9209b2ff80>
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 659, in backward
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 511, in backward
    return super(Tensor, self).split(split_size, dim)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 53, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/zero/stage2.py", line 2409, in backward
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2739, in backward

  File "/tmp/code/quickdetection/src/qd/opt/trainer.py", line 1621, in inner_loop_step
    self.model_engine.backward(loss)
  File "/tmp/code/quickdetection/src/qd/opt/trainer.py", line 1716, in inner_loop
    # not well tested
  File "/tmp/code/quickdetection/src/qd/opt/trainer.py", line 1827, in do
    if self.gradient_acc_end:
  File "/tmp/code/quickdetection/src/qd/pipelines/uni_pipeline.py", line 1900, in do_train

  File "/tmp/code/quickdetection/src/qd/pipelines/uni_pipeline.py", line 4597, in train
  File "/tmp/code/quickdetection/src/qd/pipelines/uni_pipeline.py", line 1664, in ensure_train
    ).do()
  File "/tmp/code/quickdetection/src/qd/pipeline.py", line 680, in pipeline_train_eval_multi
    pip.ensure_train()
  File "src/qd/common.py", line 3377, in execute_func
  File "src/qd/common.py", line 4416, in <module>