PyTorch flash attention error: RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED

The following error is reported when using flash attention 2 from transformers with PyTorch.

File "/opt/ml/code/train_fsdp.py", line 614, in main
    loss.backward()
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 529, in backward
    q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/saved_variable.cpp":226, please report a bug to PyTorch. No grad accumulator for a saved leaf
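
For reference, the training script sets up the model roughly like this (a minimal sketch, not the actual train_fsdp.py; the checkpoint name, tokenizer usage, and FSDP options are placeholders I'm assuming for illustration):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, AutoTokenizer

dist.init_process_group("nccl")          # the real script is launched with torchrun
torch.cuda.set_device(dist.get_rank())

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # placeholder; the actual checkpoint differs
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,          # routes attention through flash-attn 2 (transformers 4.34 flag)
)
model = FSDP(model, device_id=torch.cuda.current_device())

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
batch = tokenizer("hello world", return_tensors="pt").to(torch.cuda.current_device())
loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()                          # <- raises the INTERNAL ASSERT shown above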

Package versions in use:

pip==23.3.1
torch==2.0.1
transformers==4.34.1
tokenizers>=0.13.3
sentencepiece>=0.1.99
rouge_score==0.1.2
accelerate>=0.20.3
optimum>=1.11.1
tqdm>=4.54.1
pyarrow>=10.0.1
datasets>=1.18.2
xformer>=1.0.1
packaging>=23.1
ninja>=1.11.1
flash-attn>=2.3.3

I also tried PyTorch 2.1.0, but that produced additional, different errors.