## 🐛 Bug
Using key_padding_mask and attn_mask together with nn.MultiheadAttention causes gradients to become NaN in some use cases.
## To Reproduce
Steps to reproduce the behavior:
Backward pass through an nn.MultiheadAttention layer where the forward pass used:
1. an attn_mask limiting context in both directions (e.g. bucketed attention)
2. a key_padding_mask in which at least one sequence contains padding (and every sequence also has at least one valid entry, as expected)
3. the masked positions are not used to calculate the loss
4. the loss itself is a real number (not NaN)
```python
import torch

torch.manual_seed(0)

# Create attention layer
attn = torch.nn.MultiheadAttention(embed_dim=1, num_heads=1)

# Create dummy input: (seq_len=3, batch=2, embed_dim=1)
x = torch.rand(3, 2, 1)

# Padding mask: the second sequence only has its first position valid
key_padding_mask = torch.as_tensor([[False, False, False],
                                    [False, True, True]], dtype=torch.bool)

# Attention mask, bucketing attention to the current and previous time steps
attn_mask = torch.as_tensor([[0., float('-inf'), float('-inf')],
                             [0., 0., float('-inf')],
                             [float('-inf'), 0., 0.]])

# Generate attention embedding
output, scores = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print("scores")
print(scores)

# Create a dummy loss using only the first embedding, which is valid for all sequences
loss = output[0, :].sum()
print("loss")
print(loss)

# Backward pass and gradients
loss.backward()
print("grads")
for n, p in attn.named_parameters():
    print(n, p.grad)
```
Output:
```
scores
tensor([[[1.0000, 0.0000, 0.0000],
         [0.4468, 0.5532, 0.0000],
         [0.0000, 0.5379, 0.4621]],

        [[1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [   nan,    nan,    nan]]], grad_fn=<DivBackward0>)
loss
tensor(0.0040, grad_fn=<SumBackward0>)
grads
in_proj_weight tensor([[nan],
        [nan],
        [nan]])
in_proj_bias tensor([nan, nan, nan])
out_proj.weight tensor([[nan]])
out_proj.bias tensor([2.])
```
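My reading of the repro (an assumption, not confirmed against the implementation): for the last query of the padded sequence, attn_mask restricts attention to keys 2 and 3, while key_padding_mask masks exactly those keys, so every logit in that softmax row is -inf. Softmax over a row of all -inf produces NaN, which then propagates to every parameter in the backward pass even though the NaN row is excluded from the loss. A minimal sketch of that suspected root cause:

```python
import torch

# When a query position has no valid key left after masking, every
# attention logit in its row is -inf, and softmax over that row is NaN.
fully_masked_row = torch.tensor([float('-inf'), float('-inf'), float('-inf')])
print(torch.softmax(fully_masked_row, dim=0))  # tensor([nan, nan, nan])
```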
## Expected behavior
Gradients should not be NaN, since the masked positions do not contribute to the loss.
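A possible stopgap (an assumption on my part, not a verified fix): pass attn_mask as a float mask with a large negative finite value instead of -inf, so that no softmax row ends up entirely -inf. The fully masked query then attends somewhere spuriously, but since its output is excluded from the loss this should not affect the result:

```python
import torch

torch.manual_seed(0)
attn = torch.nn.MultiheadAttention(embed_dim=1, num_heads=1)
x = torch.rand(3, 2, 1)

# Hypothetical workaround: a large negative finite value instead of -inf,
# so no softmax row consists entirely of -inf entries.
NEG = -1e9
attn_mask = torch.as_tensor([[0., NEG, NEG],
                             [0., 0., NEG],
                             [NEG, 0., 0.]])
key_padding_mask = torch.as_tensor([[False, False, False],
                                    [False, True, True]], dtype=torch.bool)

output, scores = attn(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
loss = output[0, :].sum()
loss.backward()
for n, p in attn.named_parameters():
    print(n, p.grad)  # gradients should now be finite (under the assumption above)
```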
## Environment
PyTorch version: 1.5.1
Is debug build: No
CUDA used to build PyTorch: None
OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1
[conda] blas 1.0 mkl
[conda] cpuonly 1.0 0 pytorch
[conda] mkl 2020.1 217
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.1.0 py37h23d657b_0
[conda] mkl_random 1.1.1 py37h0573a6f_0
[conda] numpy 1.18.5 py37ha1c710e_0
[conda] numpy-base 1.18.5 py37hde5b4d6_0
[conda] pytorch 1.5.1 py3.7_cpu_0 [cpuonly] pytorch
Also fails when using GPU.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @zhangguanheng66