Causal masking hint

I am surprised by this piece of code in functional.py:

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we use the is_causal hint go as is_causal
        # indicator to SDPA.
        attn_mask = None

The first check requires an attention mask when is_causal == True. The second check says that we do not need it in one subcase (is_causal and key_padding_mask is None and not need_weights). So in that subcase, you have to pass an attention mask, only for it to be set to None :exploding_head:
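To make the two checks easier to follow, here is a minimal pure-Python model of just the snippet above. The function name resolve_attn_mask and the string stand-in for a mask tensor are hypothetical; it mirrors only the two if statements, not the rest of multi_head_attention_forward.

```python
def resolve_attn_mask(is_causal, attn_mask, key_padding_mask, need_weights):
    """Hypothetical model of the two checks quoted above.

    Returns the attn_mask that would be passed on to the attention step.
    """
    # Check 1: the is_causal hint is rejected without an explicit mask.
    if is_causal and attn_mask is None:
        raise RuntimeError("Need attn_mask if specifying the is_causal hint.")

    # Check 2: in the SDPA subcase the explicit mask is dropped
    # (in PyTorch, the is_causal flag is forwarded to SDPA instead).
    if is_causal and key_padding_mask is None and not need_weights:
        attn_mask = None

    return attn_mask


causal_mask = "explicit causal mask"  # stand-in for a real mask tensor

# The subcase from the question: the mask is required, then discarded.
print(resolve_attn_mask(True, causal_mask, None, need_weights=False))  # None

# Outside the subcase (need_weights=True), the mask is kept.
print(resolve_attn_mask(True, causal_mask, None, need_weights=True))

# Omitting the mask fails before the subcase is ever reached.
try:
    resolve_attn_mask(True, None, None, need_weights=False)
except RuntimeError as err:
    print("RuntimeError:", err)
```

In this model the explicit mask is always required up front, even on the one path where it ends up unused.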

I expected not to have to provide an attention mask if I want PyTorch to create a causal mask for me.

Is this a bug, or am I missing something?

Thank you!

According to the code comment:

    need_weights: output attn_output_weights.
        Default: `True`
        Note: `needs_weight` defaults to `True`, but should be set to `False`
        For best performance when attention weights are not needed.
        *Setting needs_weights to `True`
        leads to a significant performance degradation.*

Masking is used to ignore certain elements of the attention matrix. If the attention weights are not used at all, there is no need for masking.
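For reference, here is a minimal pure-Python sketch of what "ignoring elements" means for a single query: an additive mask puts -inf on the ignored positions, softmax turns those into zero weights, and the weighted sum of values then skips them. The helper names are hypothetical and this is not PyTorch's implementation.

```python
import math


def softmax(scores):
    """Numerically stable softmax; -inf scores get exactly zero weight."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(scores, values, mask=None):
    """One query row of attention with an optional additive mask.

    mask[j] is 0.0 to keep position j and -inf to ignore it.
    Returns (weights, output).
    """
    if mask is not None:
        scores = [s + m for s, m in zip(scores, mask)]
    weights = softmax(scores)
    output = sum(w * v for w, v in zip(weights, values))
    return weights, output


scores = [1.0, 2.0, 3.0]
values = [10.0, 20.0, 30.0]
_, out_plain = attend(scores, values)
weights, out_masked = attend(scores, values, mask=[0.0, 0.0, -math.inf])
print(weights[2])  # 0.0
print(out_plain != out_masked)  # True
```

Note that the mask changes the output as well as the weights, since the masked position no longer contributes to the weighted sum.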

That being said, I can’t think of a case where you don’t need the weights.


That’s a good observation.

PyTorch recommends setting need_weights=False (the default is True) for improved performance [source]. When is_causal=True, need_weights=False and key_padding_mask is None, the subcase is executed.

In this case, as noted in the code comment, the is_causal hint is forwarded as the is_causal flag to scaled dot product attention (SDPA), which can use a more efficient implementation that fuses multiple operations.

This path relies on SDPA building the causal mask itself from the is_causal flag (so masking is still handled internally), and SDPA does not accept an explicit attn_mask together with is_causal=True. Therefore, even if one passes attn_mask manually, it gets overridden with None.
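A minimal pure-Python sketch of the point that masking is still handled internally: an attention function given an explicit additive causal mask returns exactly the same result as one that builds that mask itself from an is_causal flag. The function names are hypothetical, values are one scalar per position for brevity, and the "not both" rule only mirrors the documented SDPA behaviour; none of this is PyTorch's fused kernel code.

```python
import math


def square_subsequent_mask(n):
    """Float causal mask: 0.0 on and below the diagonal, -inf above it.

    Modeled on the mask nn.Transformer.generate_square_subsequent_mask
    returns, written here with plain lists.
    """
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]


def attention(scores, values, attn_mask=None, is_causal=False):
    """Row-wise softmax attention over an n x n score matrix (plain lists).

    values holds one scalar per key position (a simplification).
    attn_mask is added to the scores; is_causal builds the causal mask
    internally instead, so the two are not accepted together.
    """
    n = len(scores)
    if is_causal:
        if attn_mask is not None:
            raise ValueError("pass either attn_mask or is_causal, not both")
        attn_mask = square_subsequent_mask(n)
    out = []
    for i in range(n):
        row = list(scores[i])
        if attn_mask is not None:
            row = [s + m for s, m in zip(row, attn_mask[i])]
        top = max(row)
        exps = [math.exp(s - top) for s in row]
        total = sum(exps)
        out.append(sum(e / total * v for e, v in zip(exps, values)))
    return out


scores = [[0.5, 1.0, -0.3], [0.2, 0.7, 1.5], [1.1, -0.4, 0.9]]
values = [1.0, 2.0, 3.0]

with_mask = attention(scores, values, attn_mask=square_subsequent_mask(3))
with_hint = attention(scores, values, is_causal=True)
print(with_mask == with_hint)  # True
```

The flag is just a cheaper way to say "use the causal mask", which is why an explicit mask becomes redundant on that path.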

Why not always use this fused version, then?
Likely because the fused kernels never materialize the full attention matrix, so they cannot return the attention weights. For more details, refer to FlashAttention.
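To illustrate why a fused kernel cannot hand back the weights, here is a minimal pure-Python sketch of the online-softmax trick used by FlashAttention-style kernels, for a single query row with scalar values. It produces the same output as the usual two-pass softmax but never stores the normalized weight row, which is what need_weights=True would have to return. The function names are hypothetical; real kernels process tiles in GPU memory rather than Python loops.

```python
import math


def attend_two_pass(scores, values):
    """Reference: materialize the full weight row, then take the weighted sum."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # the row need_weights would return
    return sum(w * v for w, v in zip(weights, values)), weights


def attend_streaming(scores, values):
    """Online-softmax sketch: one pass, no normalized weights are ever stored.

    Keeps a running max, a running sum of exponentials and a running
    unnormalized output, rescaling both whenever the max grows.
    """
    running_max = -math.inf
    running_sum = 0.0
    acc = 0.0
    for s, v in zip(scores, values):
        new_max = max(running_max, s)
        scale = math.exp(running_max - new_max)  # 0.0 on the first step
        p = math.exp(s - new_max)
        running_sum = running_sum * scale + p
        acc = acc * scale + p * v
        running_max = new_max
    return acc / running_sum


scores = [0.3, 2.1, -1.0, 0.8]
values = [1.0, -2.0, 0.5, 4.0]
out_ref, weights = attend_two_pass(scores, values)
out_stream = attend_streaming(scores, values)
print(math.isclose(out_ref, out_stream))  # True
```

The streaming version trades away the weights for memory and speed, which matches the need_weights=False recommendation in the docstring.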

Hope it helps!
