Masking queries in cross-attention?

Hi! I am using nn.MultiheadAttention for cross-attention, and I need to mask both the context and the input x (i.e. both the keys and the queries). However, the current implementation only exposes a key_padding_mask. Is there a way to mask the queries as well, other than building the full 2D attn_mask by hand? Thanks!
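For context, a minimal sketch of one common workaround (all shapes and masks here are hypothetical, just for illustration): mask the keys with the built-in key_padding_mask, and handle padded queries by zeroing their outputs afterwards. A fully masked attn_mask row (all -inf for one query) would produce NaNs in the softmax, which is why padded queries are usually dealt with after the attention call instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, Lq, Lk, D = 2, 4, 5, 8          # batch, query len, key len, embed dim
mha = nn.MultiheadAttention(embed_dim=D, num_heads=2, batch_first=True)

x = torch.randn(B, Lq, D)          # queries
context = torch.randn(B, Lk, D)    # keys / values

# True = padding position (example masks, made up for this sketch)
key_pad = torch.zeros(B, Lk, dtype=torch.bool)
key_pad[0, 3:] = True              # last two context tokens of sample 0 are padding
query_pad = torch.zeros(B, Lq, dtype=torch.bool)
query_pad[0, 2:] = True            # last two query tokens of sample 0 are padding

# 1) Mask the keys with the built-in key_padding_mask.
out, _ = mha(x, context, context, key_padding_mask=key_pad)

# 2) "Mask" the queries after the fact: instead of giving a padded query
#    an all -inf attn_mask row (which yields NaNs), zero out its output.
#    Downstream losses should ignore these positions anyway.
out = out.masked_fill(query_pad.unsqueeze(-1), 0.0)

print(out.shape)  # (B, Lq, D)
```

This keeps the attention call itself unchanged; whether zeroing is sufficient depends on whether anything downstream reads the padded query positions.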