Unexpected NaN in forward pass when using flash_attn

I'm using flash_attn as follows:

    from flash_attn import flash_attn_varlen_kvpacked_func
    from flash_attn.bert_padding import unpad_input

    q, q_indices, cu_q_lens, q_max_s = unpad_input(q, x_mask)
    kv, _, cu_kv_lens, kv_max_s = unpad_input(kv, source_mask)
    q = q.view(-1, self.nhead, self.head_dim).bfloat16()        # [N_unmasked, H, D]
    kv = kv.view(-1, 2, self.nhead, self.head_dim).bfloat16()   # [N_unmasked, 2, H, D]
    dropout_p = 0.1 if self.training else 0.0
    message = flash_attn_varlen_kvpacked_func(
        q, kv, cu_q_lens, cu_kv_lens, q_max_s, kv_max_s, dropout_p, softmax_scale=None)
    check_nan(message, 'message1_attn_self', {
        'q': q, 'kv': kv, 'cu_q_lens': cu_q_lens, 'cu_kv_lens': cu_kv_lens,
        'q_max_s': q_max_s, 'kv_max_s': kv_max_s,
        'x_mask': x_mask, 'source_mask': source_mask})  # ignore this line for now
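For readers less familiar with the varlen interface: unpad_input flattens the padded batch into [N_unmasked, ...] and describes where each sequence starts and ends with a cumulative-offset tensor (cu_seqlens, with a leading 0) plus the maximum sequence length. The sketch below reproduces that bookkeeping in plain PyTorch. unpad_reference is a hypothetical helper modeled on flash_attn.bert_padding.unpad_input in 2.2.x, not the library code itself; it assumes the mask is a [B, S] bool tensor with True marking real tokens.

```python
import torch
import torch.nn.functional as F


def unpad_reference(x, mask):
    """Hypothetical plain-PyTorch sketch of varlen unpadding.

    x:    [B, S, ...] padded batch
    mask: [B, S] bool, True for real (non-padding) tokens
    Returns (x_unpad, indices, cu_seqlens, max_seqlen).
    """
    seqlens = mask.sum(dim=-1, dtype=torch.int32)                  # tokens per sample
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    max_seqlen = int(seqlens.max().item())
    # Cumulative offsets with a leading 0: [0, len0, len0 + len1, ...]
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    x_unpad = x.flatten(0, 1)[indices]                             # [N_unmasked, ...]
    return x_unpad, indices, cu_seqlens, max_seqlen


x = torch.arange(6, dtype=torch.float32).view(2, 3, 1)
mask = torch.tensor([[True, True, False], [True, False, False]])
x_u, idx, cu, ms = unpad_reference(x, mask)
print(cu.tolist(), ms)  # [0, 2, 3] 2
```

Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the unpadded tensor, which is how the kernel tells sequences apart without padding.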

After a few training steps, message contains NaN, as detected by the check_nan function:

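    import torch

    def check_nan(x, name, save_dict=None):
        # Recurse into lists/tuples so each tensor is checked individually
        if isinstance(x, (list, tuple)):
            for i, xx in enumerate(x):
                check_nan(xx, f'{name}[{i}]', save_dict)
            return
        if torch.isnan(x).any():
            # Copy so the caller's dict (or a shared default) is never mutated
            save_dict = dict(save_dict or {})
            save_dict[name] = x
            torch.save(save_dict, 'nan_check_tr.pt')
            raise ValueError(f'nan detected in {name}')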

The code runs on a remote cluster, so I cannot debug it interactively. Instead, I save the context variables when a NaN is detected.
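One thing the saved dump does not capture is randomness: the training call passes dropout_p = 0.1, while the reproduction call passes no dropout_p (its default is 0.0), so a replay cannot see the same dropout mask. A possible way to make a dump replayable is to snapshot the RNG state right before the attention call and store it alongside the tensors. This is a minimal sketch, assuming flash-attn's dropout draws its seed from PyTorch's default CUDA generator (our understanding for flash-attn 2.x, not verified here). capture_rng_state and restore_rng_state are hypothetical helper names, and F.dropout stands in for the attention call so the sketch runs without flash-attn or a GPU.

```python
import torch
import torch.nn.functional as F


def capture_rng_state():
    """Snapshot the CPU RNG state and, if available, the current CUDA device's RNG state."""
    state = {'cpu': torch.get_rng_state()}
    if torch.cuda.is_available():
        state['cuda'] = torch.cuda.get_rng_state()
    return state


def restore_rng_state(state):
    """Restore a snapshot taken by capture_rng_state()."""
    torch.set_rng_state(state['cpu'])
    if 'cuda' in state and torch.cuda.is_available():
        torch.cuda.set_rng_state(state['cuda'])


# Snapshot BEFORE the random op; in training this dict would be saved with q, kv, etc.
x = torch.ones(4, 8)
rng = capture_rng_state()
y1 = F.dropout(x, p=0.1, training=True)   # stand-in for the dropout-enabled attention call
restore_rng_state(rng)
y2 = F.dropout(x, p=0.1, training=True)
print(torch.equal(y1, y2))  # True: same RNG state gives the same dropout mask
```

The snapshot has to be taken before the call, because by the time check_nan fires the generator has already advanced. In multi-GPU training each rank has its own generator, so the snapshot should come from the rank that raised.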

However, when I load the saved nan_check_tr.pt and rerun the same call on those inputs, the output contains no NaN:

    import torch
    from flash_attn import flash_attn_varlen_kvpacked_func

    pt = torch.load('nan_check_tr.pt', map_location='cuda')
    m = flash_attn_varlen_kvpacked_func(pt['q'], pt['kv'], pt['cu_q_lens'], pt['cu_kv_lens'],
                                        pt['q_max_s'], pt['kv_max_s'], softmax_scale=None)
    print(torch.isnan(m).any())  # -> tensor(False, device='cuda:0')
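To check whether the kernel output on the saved tensors is numerically plausible (and not just NaN-free), one option is to compare it with a slow fp32 reference. The sketch below is a hypothetical helper, varlen_kvpacked_attention_ref, assuming non-causal attention, no dropout, and flash-attn's default softmax_scale of 1/sqrt(head_dim); it loops over sequences using the cu_seqlens offsets and is meant for debugging, not speed.

```python
import math

import torch


def varlen_kvpacked_attention_ref(q, kv, cu_q, cu_k, softmax_scale=None):
    """Hypothetical fp32 reference for non-causal varlen attention without dropout.

    q:  [total_q, H, D]
    kv: [total_k, 2, H, D]  (index 0 = keys, index 1 = values)
    cu_q, cu_k: cumulative offsets with a leading 0 (num_seqs + 1 entries)
    """
    d = q.shape[-1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)
    out = torch.zeros_like(q, dtype=torch.float32)
    for i in range(cu_q.numel() - 1):
        qs, qe = int(cu_q[i]), int(cu_q[i + 1])
        ks, ke = int(cu_k[i]), int(cu_k[i + 1])
        qi = q[qs:qe].float().transpose(0, 1)         # [H, Lq, D]
        ki = kv[ks:ke, 0].float().transpose(0, 1)     # [H, Lk, D]
        vi = kv[ks:ke, 1].float().transpose(0, 1)     # [H, Lk, D]
        scores = qi @ ki.transpose(-1, -2) * scale    # [H, Lq, Lk]
        probs = scores.softmax(dim=-1)
        out[qs:qe] = (probs @ vi).transpose(0, 1)     # back to [Lq, H, D]
    return out
```

With the dump loaded as pt and the kernel output as m, one would compare m.float() with varlen_kvpacked_attention_ref(pt['q'], pt['kv'], pt['cu_q_lens'], pt['cu_kv_lens']) using a tolerance suited to bf16 inputs.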

I'm sure the environment is identical (the same Docker image). I have no idea what is going on. I have tried different learning rates (5e-4, 3e-4, 1e-4); a smaller learning rate trains for more steps before the NaN appears, but it always appears eventually.

The saved file is here: nan_check_tr.pt

Relevant environment:

- Docker base image: nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
- torch: 1.13.1
- flash-attn: 2.2.2
- pytorch-lightning: 1.9.0
- GPU: a single A100 (debugging) / two nodes with 8 A100s each (training)

Any suggestions are welcome. Thanks!