I want to understand the difference between attn_mask and is_causal in the MultiheadAttention class. They both seem to do the same thing: masking future tokens.
J_Johnson
(J Johnson)
January 12, 2024, 3:32am
Welcome to the forums!
Some LLM training methods, such as masked language modeling, involve masking certain words in the middle of a sentence. For example:
The ___ is brown, fluffy and has sharp claws.
A case like the above is where you would specify an attn_mask, since it lets you mask words anywhere in the sentence.
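As a rough sketch (the sizes and the masked position here are made up for illustration), a boolean attn_mask, where True means "this position may not be attended to", can hide one word from every query:

```python
import torch
import torch.nn as nn

# Illustrative sizes only.
seq_len, embed_dim, num_heads = 5, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(1, seq_len, embed_dim)

# Boolean (L, S) mask: True = "query may NOT attend to this key".
# Block every query from seeing position 1 (the blanked-out word).
attn_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
attn_mask[:, 1] = True

out, weights = mha(x, x, x, attn_mask=attn_mask)
print(weights[0, :, 1])  # all zeros: no attention reached position 1
```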
Other training methods, such as next-token prediction, mask causally, that is, in sequence order (think causality). That mask takes the shape of a square matrix whose strictly upper triangle (the future positions) is masked out, while the diagonal and everything below it remain visible.
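Concretely, the causal mask PyTorch's helper generates is a float matrix with -inf strictly above the diagonal (the future) and 0 on and below it:

```python
import torch.nn as nn

# Each row is a query position; -inf columns are the future keys it may not see.
print(nn.Transformer.generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```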
Setting is_causal=True tells PyTorch to optimize for causal attention; a different, faster algorithm gets used in that case. However, see the link at the bottom for why you cannot simply pass attn_mask=None together with is_causal=True.
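Here is a sketch of the combination that sidesteps the pitfall in the linked issue: pass the causal mask explicitly and treat is_causal=True only as an optimization hint, with need_weights=False so the fast kernels stay eligible (the sizes are arbitrary):

```python
import torch
import torch.nn as nn

seq_len, embed_dim, num_heads = 6, 16, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
x = torch.randn(2, seq_len, embed_dim)

causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
out, _ = mha(x, x, x,
             attn_mask=causal_mask,  # the mask itself stays authoritative
             is_causal=True,         # hint that attn_mask is the causal mask
             need_weights=False)     # lets the fused fast path be taken
print(out.shape)  # torch.Size([2, 6, 16])
```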
Additionally, non-NLP uses of Transformers may involve masking causally or non-causally.
You can read more on the topic in this thread:
Issue opened 03:09AM, 17 Apr 23 UTC; closed 06:00PM, 20 Apr 23 UTC. Label: oncall: transformer/mha
### 🐛 Describe the bug
When `need_weights=True`, `is_causal` is ignored in `MultiheadAttention.forward` and the result without causal masking is returned.
```python
import torch
import torch.nn as nn

batch_size = 4
seq_len = 3
embedding_dim = 8
num_heads = 2

mha = nn.MultiheadAttention(num_heads=num_heads, embed_dim=embedding_dim, batch_first=True)
x = torch.randn(batch_size, seq_len, embedding_dim)
mask = nn.Transformer.generate_square_subsequent_mask(seq_len)

# Restrict scaled-dot-product attention to the math backend.
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
    no_mask = mha(x, x, x, need_weights=False)[0]
    with_attn_mask = mha(x, x, x, need_weights=True, attn_mask=mask)[0]
    with_is_causal_need_weights = mha(x, x, x, need_weights=True, is_causal=True)[0]
    with_is_causal_no_need_weights = mha(x, x, x, need_weights=False, is_causal=True)[0]

# Succeeds: with need_weights=False, is_causal matches the explicit mask.
assert with_attn_mask.allclose(with_is_causal_no_need_weights)

# Both should succeed but fail: with need_weights=True, is_causal is silently ignored.
assert with_attn_mask.allclose(with_is_causal_need_weights), "is_causal should match regardless of 'need_weights'"
assert not no_mask.allclose(with_is_causal_need_weights), "no mask should NOT match is_causal=True, need_weights=True"
```
### Versions
```
PyTorch version: 2.0.0
CUDA used to build PyTorch: 11.8
OS: Debian GNU/Linux 10 (buster) (x86_64)
Python version: 3.10.10
Is CUDA available: True
GPU models and configuration: GPU 0: Tesla T4; GPU 1: Tesla T4
Nvidia driver version: 510.47.03
Relevant libraries: torch==2.0.0, numpy==1.23.5, pytorch-lightning==2.0.1
```
Thank you, your explanation is clear.