Relation between `is_causal` and `src_mask`?

In several Transformer components and functions, both an is_causal argument and a mask argument (src_mask / attn_mask) exist, but their definitions seem to differ from place to place and are confusing.
In my test with nn.MultiheadAttention, whether is_causal is set to True or False (the default), there is no difference in the output. Why? Note that the attention mask I used is not causal; it is just a mask for testing purposes.
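Here is a minimal sketch of the test (the dimensions, seed, and mask pattern are arbitrary choices for illustration):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    embed_dim, num_heads, seq_len = 8, 2, 4
    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    x = torch.randn(1, seq_len, embed_dim)

    # Boolean mask for nn.MultiheadAttention: True means the position
    # is NOT allowed to attend. This pattern is deliberately not causal.
    attn_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    attn_mask[0, -1] = True

    out_default, _ = mha(x, x, x, attn_mask=attn_mask, is_causal=False)
    out_causal, _ = mha(x, x, x, attn_mask=attn_mask, is_causal=True)

    print(torch.allclose(out_default, out_causal))  # prints True in my test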

In torch.nn.TransformerEncoder:

  • mask (Optional[Tensor]) – the mask for the src sequence (optional).
  • is_causal (Optional[bool]) – If specified, applies a causal mask as mask (optional) and ignores attn_mask for computing scaled dot product attention. Default: False.

In torch.nn.TransformerEncoderLayer:

  • src_mask (Optional[Tensor]) – the mask for the src sequence (optional).
  • is_causal (bool) – If specified, applies a causal mask as src_mask. Default: False.

(1) Do the above two cases mean that, if is_causal is True, an upper-triangular causal mask is created internally and replaces the provided src_mask? And what is the difference when is_causal is False but a src_mask is provided?
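For concreteness, this is how I would pass an explicit causal mask to the layer (a sketch, assuming PyTorch 2.x, where generate_square_subsequent_mask is a static method on nn.Transformer):

    import torch
    import torch.nn as nn

    seq_len, d_model = 4, 8
    layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2, batch_first=True)
    src = torch.randn(1, seq_len, d_model)

    # Float mask: 0.0 on and below the diagonal, -inf above it.
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

    # Documented pattern: pass the mask explicitly AND set is_causal=True
    # as a hint that src_mask really is causal.
    out = layer(src, src_mask=causal_mask, is_causal=True)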

In torch.nn.MultiheadAttention:

  • is_causal (bool) – If specified, applies a causal mask as attention mask. Default: False. Warning: is_causal provides a hint that attn_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.

In nn.functional.multi_head_attention_forward:

    is_causal: If specified, applies a causal mask as attention mask, and ignores
        attn_mask for computing scaled dot product attention.
        Default: ``False``.
        .. warning::
            is_causal provides a hint that the attn_mask is the
            causal mask. Providing incorrect hints can result in
            incorrect execution, including forward and backward
            compatibility.
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

(2) If is_causal=True means attn_mask is ignored, why raise this error demanding an attn_mask?
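As I read the error message, it asks for the pattern below: the mask must still be supplied, and is_causal=True only promises that it is causal (a sketch; whether the hint changes anything presumably depends on which attention kernel gets selected):

    import torch
    import torch.nn as nn

    seq_len, embed_dim = 4, 8
    mha = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)
    x = torch.randn(1, seq_len, embed_dim)

    # Supply the causal mask explicitly, plus the consistency hint.
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
    out, _ = mha(x, x, x, attn_mask=causal_mask, is_causal=True, need_weights=False)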

In torch.nn.functional.scaled_dot_product_attention:

  • attn_mask (Optional[Tensor]) – Attention mask; shape (N, …, L, S). Two types of masks are supported: a boolean mask, where a value of True indicates that the element should take part in attention, and a float mask of the same dtype as query, key, value that is added to the attention score.
  • dropout_p (float) – Dropout probability; if greater than 0.0, dropout is applied.
  • is_causal (bool) – If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.

(3) β€œerrors if both attn_mask and is_causal are set” – so this function does not allow attn_mask and is_causal to be set together. Are the two arguments designed to be mutually exclusive here, unlike in the modules above?
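To check my understanding, I compared the flag against an explicit mask (a sketch; note the boolean convention here is True = may attend, opposite to nn.MultiheadAttention):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 2, 4, 8)  # (N, num_heads, L, E)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)

    # Option A: let the kernel apply causal masking itself.
    out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Option B: build the equivalent lower-triangular boolean mask (True = attend).
    causal_bool = torch.tril(torch.ones(4, 4, dtype=torch.bool))
    out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bool)

    print(torch.allclose(out_flag, out_mask, atol=1e-6))  # True for me
    # Per the docs, passing attn_mask AND is_causal=True together errors here.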

