Use attn_mask to enforce causality in MultiheadAttention?

Hi Team,

Following https://github.com/pytorch/pytorch/issues/21518, which tracks the additive versus multiplicative attn_mask issue, I tested with the following code:


import torch

attn_mask = torch.tril(torch.ones((8, 8)))  # lower-triangular float mask of 1s and 0s
inputs = torch.ones((8, 2, 6))              # (seq_len, batch, embed_dim)
mha = torch.nn.MultiheadAttention(6, 2)     # embed_dim=6, num_heads=2
outputs, weights = mha(inputs, inputs, inputs, attn_mask=attn_mask)  # Q, K, V, attn_mask for causality

to enforce causality, but the returned attention weights suggest it still attends to future inputs:

 tensor([[[0.2797, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029],
          [0.2377, 0.2377, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874],
          [0.2066, 0.2066, 0.2066, 0.0760, 0.0760, 0.0760, 0.0760, 0.0760],
          [0.1828, 0.1828, 0.1828, 0.1828, 0.0672, 0.0672, 0.0672, 0.0672],
          [0.1638, 0.1638, 0.1638, 0.1638, 0.1638, 0.0603, 0.0603, 0.0603],
          [0.1485, 0.1485, 0.1485, 0.1485, 0.1485, 0.1485, 0.0546, 0.0546],
          [0.1357, 0.1357, 0.1357, 0.1357, 0.1357, 0.1357, 0.1357, 0.0499],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]],
 
         [[0.2797, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029, 0.1029],
          [0.2377, 0.2377, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874],
          [0.2066, 0.2066, 0.2066, 0.0760, 0.0760, 0.0760, 0.0760, 0.0760],
          [0.1828, 0.1828, 0.1828, 0.1828, 0.0672, 0.0672, 0.0672, 0.0672],
          [0.1638, 0.1638, 0.1638, 0.1638, 0.1638, 0.0603, 0.0603, 0.0603],
          [0.1485, 0.1485, 0.1485, 0.1485, 0.1485, 0.1485, 0.0546, 0.0546],
          [0.1357, 0.1357, 0.1357, 0.1357, 0.1357, 0.1357, 0.1357, 0.0499],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]],
        grad_fn=<DivBackward0>))

Does that mean it is still an additive mask in the current implementation (I used PyTorch 1.6.0+cu101 on Google Colab)?

Thanks!

Resolved, please check https://github.com/pytorch/pytorch/issues/21518. It looks like the mask is still additive under the hood (kept that way for numerical stability): attn_mask is added to the QK^T attention scores before the softmax. So it behaves correctly once the positions that should be blocked are filled with -inf, rather than passing a 0/1 multiplicative-style mask as I did above.
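
For reference, here is a minimal sketch of what worked for me, assuming the PyTorch 1.6 semantics where a float attn_mask is added to the attention scores before the softmax (variable names are just illustrative):

import torch

seq_len, batch_size, embed_dim, num_heads = 8, 2, 6, 2

# Additive causal mask: 0.0 where attention is allowed, -inf on strictly-future
# positions, so the softmax assigns them zero weight.
future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
attn_mask = torch.zeros(seq_len, seq_len).masked_fill(future, float('-inf'))

inputs = torch.ones(seq_len, batch_size, embed_dim)  # (seq_len, batch, embed_dim)
mha = torch.nn.MultiheadAttention(embed_dim, num_heads)
outputs, weights = mha(inputs, inputs, inputs, attn_mask=attn_mask)

print(weights)  # strictly-future positions now get weight 0

If I read the docs correctly, a BoolTensor attn_mask (True = position not allowed to attend), e.g. passing `future` directly, should also be accepted in this version, though I only verified the float/-inf variant above.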