I am surprised by this piece of code in `functional.py`:
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we use the is_causal hint go as is_causal
        # indicator to SDPA.
        attn_mask = None
The first check requires an attention mask whenever is_causal is True. The second check says the mask is not actually needed in a subcase (`is_causal` and `key_padding_mask is None` and `not need_weights`). So in that subcase you still have to pass an attention mask, only for it to be set back to None.
I was expecting not to have to provide an attention mask at all when I want PyTorch to build the causal mask for me.
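Concretely, here is the call pattern this forces (a minimal sketch; module sizes and shapes are arbitrary, for illustration only):

    import torch
    import torch.nn as nn

    embed_dim, num_heads, seq_len, batch = 64, 4, 10, 2  # arbitrary sizes

    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    x = torch.randn(batch, seq_len, embed_dim)

    # The first check forces us to build and pass a causal mask explicitly,
    # even though we also pass is_causal=True.
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

    out, weights = mha(
        x, x, x,
        attn_mask=causal_mask,
        is_causal=True,
        need_weights=False,  # with key_padding_mask=None this hits the subcase
    )
    # Inside the subcase the mask we just built is dropped (attn_mask = None)
    # and only the is_causal flag is forwarded to SDPA; weights is None here.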
need_weights: output attn_output_weights. Default: `True`.
Note: `need_weights` defaults to `True`, but should be set to `False` for best performance when attention weights are not needed. *Setting `need_weights` to `True` leads to a significant performance degradation.*
Masking is done to ignore certain elements in the attention matrix. If the attention weights were not used at all, there would be no need for masking.
That being said, I can't think of a case where you don't need the weights.
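For what it's worth, my understanding (not stated in the snippet above) is that `need_weights` only controls whether the weights are returned to the caller; the output is still a weighted combination of the values either way, the fused kernels just never materialize the full weight matrix. A quick sketch:

    import torch
    import torch.nn as nn

    mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
    x = torch.randn(2, 10, 64)

    # need_weights only decides whether the (head-averaged) attention weights
    # are handed back; the attention output itself is always computed.
    out, w = mha(x, x, x, need_weights=True)
    print(w.shape)   # torch.Size([2, 10, 10])

    out, w = mha(x, x, x, need_weights=False)
    print(w)         # None -- weights not returned, faster paths become possible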
PyTorch recommends setting need_weights=False (the default is True) for improved performance[source]. When need_weights=False and key_padding_mask is None, the subcase above is executed.
In this case, as noted in the comments, PyTorch uses a more efficient implementation of scaled dot product attention (SDPA) by fusing multiple operations.
This efficient implementation requires attn_mask=None (the causal masking is still handled internally). Therefore, even if you pass attn_mask manually, it gets overridden.
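Roughly speaking (a sketch of the fast path, not the exact internal call), what ends up running is equivalent to calling `scaled_dot_product_attention` directly with no mask tensor:

    import torch
    import torch.nn.functional as F

    # Hypothetical per-head tensors: (batch, num_heads, seq_len, head_dim)
    q = k = v = torch.randn(2, 4, 10, 16)

    # No mask tensor is passed; causality is expressed only through the flag,
    # which is what allows the fused kernels to be selected.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)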
Why not always use this fused version, then?
Likely because fusion prevents returning the attention weights. For more details, see the FlashAttention paper.
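To see why returning the weights is awkward for fused kernels, compare with an unfused reference (a sketch with arbitrary shapes): it materializes the full seq_len x seq_len weight matrix, which FlashAttention-style kernels deliberately never form in memory, so there is nothing to hand back.

    import math
    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(2, 4, 10, 16)

    # Unfused reference: the (seq_len x seq_len) weight matrix is built explicitly.
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    causal = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    weights = scores.softmax(dim=-1)  # the matrix fused kernels never materialize
    out = weights @ v

    # The fused call gives the same result without ever forming `weights`.
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(out, ref, atol=1e-5))  # True, up to floating point error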