Hi All,
I'm trying to understand why the attention mask does not seem to work. There is a similar question here that seems to resolve the issue, but the same code below doesn't produce 0s in the upper triangle of the attention matrix. My torch version is 1.13.1.
```python
import torch
import torch.nn as nn

inputs = torch.ones((4, 3))  # unbatched input: (seq_len=4, embed_dim=3)
maten = nn.MultiheadAttention(embed_dim=3, num_heads=1)
attn_mask = torch.tril(torch.ones((4, 4)))
attn_mask = attn_mask > 0  # True on the lower triangle (allowed positions)
# ~attn_mask is True on the upper triangle, i.e. those positions are masked out
query2, _ = maten(inputs, inputs, inputs, attn_mask=~attn_mask)
query2
```
```
tensor([[-0.0128, 0.0919, 0.1716],
        [-0.0128, 0.0919, 0.1716],
        [-0.0128, 0.0919, 0.1716],
        [-0.0128, 0.0919, 0.1716]], grad_fn=<SqueezeBackward1>)
```
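In case it helps diagnose, this is how I understand the attention matrix itself can be inspected (a minimal sketch continuing from the snippet above, assuming the default `need_weights=True`, which returns the head-averaged weights of shape `(L, S)` as the second output):

```python
# Sketch: look at the attention weights rather than the attention output.
# With a bool attn_mask, True positions get -inf before the softmax, so the
# masked upper triangle should come out as exact 0s here.
_, attn_weights = maten(inputs, inputs, inputs, attn_mask=~attn_mask)
print(attn_weights)
```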