Hi! I am using nn.MultiheadAttention for cross-attention, and I need to mask both the context and the x (i.e., both the keys and the queries). However, it looks like the current implementation only provides key_padding_mask for the keys. How can I also mask the queries, other than building the 2D attn_mask directly? Thanks!
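For concreteness, here is a minimal sketch of the setup (shapes and names are illustrative, not from any particular model). It uses key_padding_mask for the context, and relies on the fact that attention never mixes information across query positions, so padded query rows can simply be zeroed out after the call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

batch, q_len, k_len = 2, 4, 5
x = torch.randn(batch, q_len, embed_dim)        # queries
context = torch.randn(batch, k_len, embed_dim)  # keys / values

# Boolean padding masks: True marks a padded (invalid) position.
query_pad = torch.tensor([[False, False, True,  True],
                          [False, False, False, True]])
key_pad   = torch.tensor([[False, False, False, True,  True],
                          [False, False, True,  True,  True]])

# Keys are masked natively via key_padding_mask.
out, _ = mha(x, context, context, key_padding_mask=key_pad)

# Queries: each output row depends only on its own query, so the
# rows at padded query positions can be zeroed (or ignored) here.
out = out.masked_fill(query_pad.unsqueeze(-1), 0.0)
```

Whether this is sufficient depends on what happens downstream; if a later layer pools over the query dimension, the padded rows must stay excluded there too.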