In several Transformer components and functions, both an is_causal and a src_mask argument exist, but their definitions seem to differ and are confusing.
In my test with nn.MultiheadAttention, the output is the same whether is_causal is set to True or False (the default). Why? BTW, the attention mask I used is not causal; it is only a mask for testing purposes.
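Since the original test code is not included here, the following is a hypothetical reproduction of the same comparison, not the code actually used: the sizes, the seed, and the non-causal test_mask (which blocks the last key position for every query) are assumptions, and need_weights is left at its default of True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, embed_dim, num_heads = 4, 2, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)  # defaults: batch_first=False, dropout=0.0
x = torch.randn(seq_len, batch, embed_dim)         # (seq_len, batch, embed_dim)

# Hypothetical NON-causal test mask. In nn.MultiheadAttention a boolean True
# marks a position that may NOT be attended to; here every query is blocked
# from attending to the last key position.
test_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
test_mask[:, -1] = True

# Same mask, only the is_causal hint differs; need_weights stays at its default (True).
out_false, _ = mha(x, x, x, attn_mask=test_mask, is_causal=False)
out_true, _ = mha(x, x, x, attn_mask=test_mask, is_causal=True)

print(torch.allclose(out_false, out_true))
```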
In torch.nn.TransformerEncoder:
- mask (Optional[Tensor]): the mask for the src sequence (optional).
- is_causal (Optional[bool]): If specified, applies a causal mask as mask (optional) and ignores attn_mask for computing scaled dot product attention. Default: False.
In torch.nn.TransformerEncoderLayer:
- src_mask (Optional[Tensor]): the mask for the src sequence (optional).
- is_causal (bool): If specified, applies a causal mask as src_mask. Default: False.
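For reference, a minimal sketch of passing a genuinely causal mask together with the hint, assuming PyTorch 2.x; the sizes are arbitrary and dropout is set to 0.0 so that repeated calls are deterministic. Note the differing keyword names: the layer takes src_mask, while the encoder stack takes mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, d_model, nhead = 5, 2, 8, 2
layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=16, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)
src = torch.randn(seq_len, batch, d_model)  # (seq_len, batch, d_model)

# Float mask: 0.0 on and below the diagonal, -inf above it (future positions blocked).
causal = nn.Transformer.generate_square_subsequent_mask(seq_len)

# Single layer: the mask goes in as `src_mask`.
layer_hint = layer(src, src_mask=causal, is_causal=True)
layer_plain = layer(src, src_mask=causal, is_causal=False)

# Encoder stack: the mask goes in as `mask`; is_causal is left unset in the second call.
enc_hint = encoder(src, mask=causal, is_causal=True)
enc_plain = encoder(src, mask=causal)

# When the hint matches the mask, hinted and un-hinted calls should agree
# up to floating-point rounding.
print(torch.allclose(layer_hint, layer_plain, atol=1e-5))
print(torch.allclose(enc_hint, enc_plain, atol=1e-5))
```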
(1) Do the two cases above mean that if is_causal is True, a new upper-triangular causal mask is created internally to replace the provided src_mask? And what is the difference if is_causal is False and src_mask is also provided?
In torch.nn.MultiheadAttention:
- is_causal (bool): If specified, applies a causal mask as attention mask. Default: False. Warning: is_causal provides a hint that attn_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
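A sketch of the intended use of this hint, under the assumption that the supplied attn_mask really is causal; the sizes and seed are arbitrary, and need_weights=False is set here, so the returned attention weights are None.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch, embed_dim, num_heads = 4, 2, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)
x = torch.randn(seq_len, batch, embed_dim)

# nn.MultiheadAttention's boolean convention: True marks a position that is NOT
# allowed to be attended to, so a causal mask is the strict upper triangle.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out_hint, weights_hint = mha(x, x, x, attn_mask=causal_mask, is_causal=True, need_weights=False)
out_plain, weights_plain = mha(x, x, x, attn_mask=causal_mask, is_causal=False, need_weights=False)

# An honest hint should not change the result (up to rounding).
print(torch.allclose(out_hint, out_plain, atol=1e-5), weights_hint is None)
```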
In nn.functional.multi_head_attention_forward:
is_causal: If specified, applies a causal mask as attention mask, and ignores
attn_mask for computing scaled dot product attention.
Default: ``False``.
.. warning::
is_causal is provides a hint that the attn_mask is the
causal mask.Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
```python
if is_causal and attn_mask is None:
    raise RuntimeError(
        "Need attn_mask if specifying the is_causal hint. "
        "You may use the Transformer module method "
        "`generate_square_subsequent_mask` to create this mask."
    )
```
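A quick way to trigger the quoted check, assuming a PyTorch 2.x version in which nn.MultiheadAttention forwards is_causal to multi_head_attention_forward as in the excerpt above; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(3, 1, 8)  # (seq_len, batch, embed_dim)

try:
    mha(x, x, x, is_causal=True)  # the hint is given, but no attn_mask is passed
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)
```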
(2) If is_causal=True means attn_mask can be ignored, why raise this error?
In torch.nn.functional.scaled_dot_product_attention:
- attn_mask (optional Tensor): Attention mask; shape (N, ..., L, S). Two types of masks are supported. A boolean mask where a value of True indicates that the element should take part in attention. A float mask of the same type as query, key, value that is added to the attention score.
- dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
- is_causal (bool): If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.
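A small check of the two mask types described above, assuming PyTorch 2.x and 3-D (batch, seq, dim) inputs; it compares is_causal=True against an explicit boolean lower-triangular mask and its float (0 / -inf) equivalent. Note that this boolean convention (True = take part) is the opposite of nn.MultiheadAttention's, where True marks a blocked position.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, L, E = 2, 4, 8  # batch, sequence length, embedding dim (arbitrary)
q, k, v = (torch.randn(N, L, E) for _ in range(3))

# Boolean mask for SDPA: True means "may attend" (lower triangle incl. diagonal).
bool_causal = torch.ones(L, L, dtype=torch.bool).tril()

# Equivalent float mask: 0.0 where attention is allowed, -inf where it is blocked.
float_causal = torch.zeros(L, L).masked_fill(~bool_causal, float("-inf"))

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_causal)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_causal)

print(torch.allclose(out_flag, out_bool, atol=1e-5), torch.allclose(out_flag, out_float, atol=1e-5))
```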
(3) "errors if both attn_mask and is_causal are set": this function does not allow attn_mask and is_causal to be set together. Are they designed to behave differently?
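A sketch of the documented error case, assuming PyTorch 2.x on CPU with 3-D inputs; since the exact error message may vary between versions, only the exception type is checked.

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 3, 4)                  # (batch, seq, dim), arbitrary sizes
mask = torch.ones(3, 3, dtype=torch.bool).tril()  # a causal boolean mask

try:
    # Both an explicit mask and the is_causal flag: documented to raise.
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True)
    raised = False
except RuntimeError:
    raised = True

print(raised)
```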