Masking in torch.nn.MultiheadAttention

Hi there! I am using nn.MultiheadAttention to construct a transformer encoder layer. Suppose my queries, keys, and values are all the same tensor (e.g. h), so I call it as h, score = MHA(h, h, h). This means that part of the computation is each position attending to itself. Is there a way to mask this away?

I think torch.nn.MultiheadAttention has a mask argument for this, attn_mask, in its forward method:

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
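Here is a minimal sketch of blocking the "attend to yourself" entries with attn_mask. It assumes a PyTorch version that accepts boolean masks (True = "not allowed to attend") and the default batch_first=False layout (sequence, batch, embedding). The sizes are made up for illustration.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not from the original post)
seq_len, batch, embed_dim, num_heads = 5, 2, 16, 4

mha = nn.MultiheadAttention(embed_dim, num_heads)  # batch_first=False: inputs are (L, N, E)
h = torch.randn(seq_len, batch, embed_dim)

# Boolean attn_mask of shape (L, S): True marks query/key pairs that are NOT allowed.
# True on the diagonal stops position i from attending to position i.
self_mask = torch.eye(seq_len, dtype=torch.bool)

out, weights = mha(h, h, h, attn_mask=self_mask)

print(out.shape)      # torch.Size([5, 2, 16])
print(weights.shape)  # torch.Size([2, 5, 5]), weights averaged over heads
# Masked scores are set to -inf before the softmax, so the diagonal weights are exactly 0
print(weights.diagonal(dim1=-2, dim2=-1))  # all zeros
```

Equivalently, you can pass a float mask with float('-inf') on the diagonal and 0 elsewhere, since float masks are added to the attention scores. One caveat: every query needs at least one key it is allowed to attend to; a fully masked row (e.g. seq_len == 1 with this mask) produces NaN after the softmax.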

You can watch this video for an explanation https://www.youtube.com/watch?v=_iDanMWVj98&ab_channel=PytorchModulesExplained