Segfault in autograd after switching to PyTorch Lightning

I am stuck trying to understand and fix this problem. I have a model that trains successfully (i.e. without errors) with a manual training loop. However, when I moved training to PyTorch Lightning, I get a segmentation fault at the end of the first batch.
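For context, the Lightning side of my setup has roughly the shape below. This is a minimal, self-contained sketch with a placeholder model and random data, not my actual code; the module, sizes, and Trainer arguments are stand-ins.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Placeholder network; the real model is much larger.
        self.net = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Random tensors stand in for the real dataset.
dataset = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(LitModel(), DataLoader(dataset, batch_size=32))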

CUDA                      12.4
torch                     2.6.0+cu124
pytorch-lightning         2.5.1.post0

I have a gdb backtrace, which I can reproduce reliably but cannot make sense of:
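(For reference, I capture the trace by running the training script under gdb along these lines, where the script name is just a placeholder for my actual entry point, and then issuing bt at the gdb prompt once the SIGSEGV is reported:

gdb -ex run --args python train.py
)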

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007fffd076acb2 in std::__detail::__variant::__gen_vtable_impl<true, std::__detail::__variant::_Multi_array<std::__detail::__variant::__variant_cookie (*)(std::__detail::__variant::_Variant_storage<false, c10::SmallVector<c10::SymInt, 5u>, at::Tensor>::_M_reset_impl()::{lambda(auto:1&&)#1}&&, std::variant<c10::SmallVector<c10::SymInt, 5u>, at::Tensor>&)>, std::tuple<std::variant<c10::SmallVector<c10::SymInt, 5u>, at::Tensor>&>, std::integer_sequence<unsigned long, 0ul> >::__visit_invoke(std::__detail::__variant::_Variant_storage<false, c10::SmallVector<c10::SymInt, 5u>, at::Tensor>::_M_reset_impl()::{lambda(auto:1&&)#1}&&, std::variant<c10::SmallVector<c10::SymInt, 5u>, at::Tensor>&) ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
(gdb) bt
#0  0x00007fffd076acb2 in std::__detail::__variant::__gen_vtable_impl<true, std::__detail::__variant::_Multi_array<std::__detail::__variant::__variant_cookie (*)(std::__detail::__variant::_Variant_storage<false, c10::SmallVector<c10::SymInt, 5u>, at::Tensor>::_M_reset_impl()::{lambda(auto:1&&)#1}&&, std::variant<c10::SmallVector<c10::SymInt, 5u>, at::Tensor>&)>, std::tuple<std::variant<c10::SmallVector<c10::SymInt, 5u>, at::Tensor>&>, std::integer_sequence<unsigned long, 0ul> >::__visit_invoke(std::__detail::__variant::_Variant_storage<false, c10::SmallVector<c10::SymInt, 5u>, at::Tensor>::_M_reset_impl()::{lambda(auto:1&&)#1}&&, std::variant<c10::SmallVector<c10::SymInt, 5u>, at::Tensor>&) ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#1  0x00007fffd077272d in torch::autograd::Node::~Node() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#2  0x00007fffd07a7a84 in torch::autograd::generated::SelectBackward0::~SelectBackward0() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#3  0x00007fffd1407f54 in torch::autograd::deleteNode(torch::autograd::Node*) ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#4  0x00007fffd0b81d4e in std::_Sp_counted_deleter<torch::autograd::generated::SelectBackward0*, void (*)(torch::autograd::Node*), std::allocator<void>, (__gnu_cxx::_Lock_policy)2>::_M_dispose() () from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#5  0x00007fffd1407fa9 in torch::autograd::deleteNode(torch::autograd::Node*) ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#6  0x00007fffd0b81d4e in std::_Sp_counted_deleter<torch::autograd::generated::SelectBackward0*, void (*)(torch::autograd::Node*), std::allocator<void>, (__gnu_cxx::_Lock_policy)2>::_M_dispose() () from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#7  0x00007fffd141c8f6 in torch::autograd::CopyBackwards::~CopyBackwards() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#8  0x00007fffd141de74 in torch::autograd::CopySlices::~CopySlices() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#9  0x00007fffd141ddd6 in torch::autograd::CopySlices::~CopySlices() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#10 0x00007fffd141ddd6 in torch::autograd::CopySlices::~CopySlices() ()

........... (86k frames later) ...........

#86352 0x00007fffd13e7fd0 in torch::autograd::AutogradMeta::~AutogradMeta() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so
#86353 0x00007fff96818332 in c10::TensorImpl::~TensorImpl() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libc10.so
#86354 0x00007fff968184e9 in c10::TensorImpl::~TensorImpl() ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libc10.so
#86355 0x00007fffe1525d38 in THPVariable_subclass_clear(THPVariable*) ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_python.so
#86356 0x00007fffe1526090 in THPVariable_subclass_dealloc(_object*) ()
   from /gpfs/projects/b1038/Pulmonary/nmarkov/scythe/venv/lib/python3.12/site-packages/torch/lib/libtorch_python.so
#86357 0x0000555555672f61 in _PyEval_EvalFrameDefault (tstate=tstate@entry=0x555555bdbe90 <_PyRuntime+458992>, frame=0x7ffff7e805a0,
    frame@entry=0x7ffff7e80418, throwflag=throwflag@entry=0) at Python/bytecodes.c:209
#86358 0x00005555557993ac in _PyEval_EvalFrame (throwflag=0, frame=0x7ffff7e80418, tstate=0x555555bdbe90 <_PyRuntime+458992>)
    at /usr/local/src/conda/python-3.12.10/Include/internal/pycore_ceval.h:89
#86359 _PyEval_Vector (kwnames=<optimized out>, argcount=<optimized out>, args=0x7ffe7ae1bf30, locals=0x0, func=0x7ffe7b598f40,
    tstate=0x555555bdbe90 <_PyRuntime+458992>) at /usr/local/src/conda/python-3.12.10/Python/ceval.c:1685
#86360 _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=0x7ffe7ae1bf30, func=0x7ffe7b598f40)
    at /usr/local/src/conda/python-3.12.10/Objects/call.c:419
#86361 _PyObject_VectorcallTstate (tstate=0x555555bdbe90 <_PyRuntime+458992>, callable=0x7ffe7b598f40, args=0x7ffe7ae1bf30, nargsf=<optimized out>,
    kwnames=<optimized out>) at /usr/local/src/conda/python-3.12.10/Include/internal/pycore_call.h:92
#86362 0x0000555555798f0e in method_vectorcall (method=method@entry=0x7ffe7ae06200, args=args@entry=0x7ffe7ae1bf38, nargsf=<optimized out>,
    kwnames=kwnames@entry=0x7ffe7ae2c4f0) at /usr/local/src/conda/python-3.12.10/Objects/classobject.c:61
#86363 0x000055555578129b in _PyVectorcall_Call (kwargs=<optimized out>, tuple=<optimized out>, callable=0x7ffe7ae06200,
    func=0x555555798c30 <method_vectorcall>, tstate=0x555555bdbe90 <_PyRuntime+458992>) at /usr/local/src/conda/python-3.12.10/Objects/call.c:283
#86364 _PyObject_Call (tstate=0x555555bdbe90 <_PyRuntime+458992>, callable=0x7ffe7ae06200, args=<optimized out>, kwargs=<optimized out>)
    at /usr/local/src/conda/python-3.12.10/Objects/call.c:354
...

I would appreciate any help figuring this out.

Do you see the same issue using the latest stable or nightly release?