NaN/-inf appear in the attention weights of a manual scaled_dot_product_attention

I use the equivalent Python code from the official PyTorch documentation in place of nn.functional.scaled_dot_product_attention.

import math

import torch


def scaled_dot_product_attention_manual(self, query, key, value, attn_mask=None,
                                        dropout_p=0.0, is_causal=False,
                                        scale=None) -> tuple[torch.Tensor, torch.Tensor]:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Build the additive bias at (L, S); deriving its shape from attn_mask
    # crashes whenever attn_mask is None (e.g. in the is_causal branch).
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Out-of-place masked_fill so a batched (..., L, S) bool mask broadcasts.
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Despite the name, attn_logits holds the post-softmax probabilities.
    attn_logits = attn_weight

    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_logits
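A minimal way to reproduce the symptom (the shapes, the fully-masked row, and the CUDA device below are stand-ins I made up for illustration, not my real q_states/kv_states):

    q = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
    mask = torch.ones(1, 1, 4, 4, dtype=torch.bool, device="cuda")
    mask[..., 0, :] = False  # query row 0 attends to nothing (fully masked)
    # self is unused in the function body, so pass None for it
    out, probs = scaled_dot_product_attention_manual(None, q, k, v, attn_mask=mask)
    print(torch.isnan(probs).any())  # True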

But the returned attn_logits (the post-softmax probabilities) contains NaNs, and the pre-softmax attn_weight contains -inf.

The dtype of these tensors is torch.float16.
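I suspect the NaNs come from query rows where every key position is masked: for such a row the softmax input is all -inf, and softmax then returns NaN. A small check (my own illustration, not taken from my model):

    row = torch.full((4,), float("-inf"), dtype=torch.float16)
    print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan], dtype=torch.float16)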

When I apply nn.functional.scaled_dot_product_attention directly to the same q_states and kv_states, the output is normal.
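One modification I'm considering, though I'm not sure it is the right fix, is to run the softmax in float32 and cast back (I understand the fused kernels accumulate in higher precision internally), plus zeroing the fully-masked rows:

    # Sketch of a possible change inside scaled_dot_product_attention_manual:
    attn_weight = torch.softmax(attn_weight.float(), dim=-1).to(query.dtype)
    attn_weight = torch.nan_to_num(attn_weight)  # fully masked rows: NaN -> 0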

Could someone tell me how to modify my own sdpa function?