I am using the equivalent Python code from the official PyTorch documentation to replace nn.functional.scaled_dot_product_attention:
```python
import math

import torch


def scaled_dot_product_attention_manual(
    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) -> tuple[torch.Tensor, torch.Tensor]:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Unlike the reference code, which allocates torch.zeros(L, S, ...), the bias
    # here takes attn_mask.shape, so this path requires attn_mask to be provided.
    attn_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Lower-triangular causal mask; positions above the diagonal are set to -inf.
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)  # no-op: the result is not assigned

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean mask: False positions are masked out with -inf.
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            # Float mask: added directly to the attention scores.
            attn_bias += attn_mask

    # Scaled dot-product scores plus mask bias, then softmax over the key dimension.
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias.to(query.device)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_logits = attn_weight  # post-softmax probabilities, saved before dropout
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_logits
```
However, attn_logits contains NaNs and attn_weight contains -inf values. The dtype of these tensors is torch.float16.
When I apply nn.functional.scaled_dot_product_attention directly to the same q_states and kv_states, the output is normal.
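To make the comparison concrete, here is roughly the check I run (the shapes, device, and mask below are made up for illustration; it assumes the function above is defined at module level, and I pass None for the unused self argument so it can be called standalone):

```python
import torch
import torch.nn.functional as F

# Made-up shapes for illustration: (batch, heads, seq_len, head_dim), float16 on GPU.
q = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 128, 64, dtype=torch.float16, device="cuda")
# Boolean mask broadcastable to (batch, heads, L, S); True means "attend".
mask = torch.rand(2, 1, 128, 128, device="cuda") > 0.1

# Built-in SDPA on the same inputs.
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Manual version; self is unused in its body, so pass None to call it standalone.
out, probs = scaled_dot_product_attention_manual(None, q, k, v, attn_mask=mask)

print("built-in output has NaN:", torch.isnan(ref).any().item())
print("manual output has NaN:  ", torch.isnan(out).any().item())
print("manual probs have NaN:  ", torch.isnan(probs).any().item())
```

With my real q_states and kv_states, the manual checks report NaNs while the built-in call stays clean.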
Could someone tell me how to modify my own sdpa function?