Masks in transformers

I want to understand the difference between attn_mask and is_causal in the MultiheadAttention class. I feel both do the same thing: masking future tokens.

Welcome to the forums!

Some LLM training methods involve masking certain words in the middle of a sentence. For example:

The ___ is brown, fluffy and has sharp claws.

The above is a case where you might specify an attn_mask, which lets you mask out positions anywhere in the sentence.
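As a minimal sketch (the tensor sizes and the particular position being hidden are just illustrative assumptions, not from the original post), an attn_mask that blocks every query from attending to one chosen position could look like this with nn.MultiheadAttention:

```python
import torch
import torch.nn as nn

embed_dim, num_heads, seq_len = 16, 4, 8
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(1, seq_len, embed_dim)  # (batch, seq, embed)

# Boolean attn_mask of shape (L, S): True means "this query may NOT attend
# to this key". Here every query is blocked from seeing position 1,
# i.e. the blanked-out word in the example sentence above.
attn_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
attn_mask[:, 1] = True

out, attn_weights = mha(x, x, x, attn_mask=attn_mask)
print(attn_weights[0, :, 1])  # ~0 everywhere: nothing attends to position 1
```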

Other training methods, such as next-token prediction, involve masking causally, that is, in sequence order (think causality). That mask takes the shape of a matrix with zeros in the strict upper triangle and ones everywhere else, including the diagonal.
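For example, as a small sketch (the sequence length is arbitrary), the "ones on and below the diagonal, zeros above" pattern can be built with torch.tril, and the inverted boolean version that nn.MultiheadAttention expects (True = not allowed to attend) with torch.triu:

```python
import torch

seq_len = 5

# "Allowed" view: ones on and below the diagonal, zeros in the strict
# upper triangle (the future positions each token cannot see).
allowed = torch.tril(torch.ones(seq_len, seq_len))

# Boolean attn_mask convention for nn.MultiheadAttention is the opposite:
# True marks positions a query is NOT allowed to attend to.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

print(allowed)
print(causal_mask)
```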

Setting is_causal = True tells PyTorch to optimize for causal attention; a different, faster attention path can be used in that case. However, see the link at the bottom for why you cannot pass attn_mask = None together with is_causal = True.
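A hedged sketch, assuming a PyTorch 2.x nn.MultiheadAttention whose forward accepts is_causal: the causal mask is still passed in explicitly, and is_causal=True only asserts that the mask really is causal so the optimized path can be chosen.

```python
import torch
import torch.nn as nn

embed_dim, num_heads, seq_len = 16, 4, 5
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(2, seq_len, embed_dim)

# Boolean causal mask: True above the diagonal hides future tokens.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# is_causal=True is only a hint that attn_mask is the causal mask,
# so the mask itself is still provided rather than passing attn_mask=None.
out, _ = mha(x, x, x, attn_mask=causal_mask, is_causal=True)
```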

Additionally, non-NLP uses of Transformers may involve masking causally or non-causally.

You can read more on the topic in this thread:


Thank you, your explanation is clear.