https://pytorch.org/docs/master/nn.html?highlight=multiheadattention#torch.nn.MultiheadAttention
nn.MultiheadAttention accepts key_padding_mask and attn_mask arguments.
It's easy to build a key_padding_mask like this:
import torch

def get_mask_from_lengths(lengths):
    # lengths: (batch,) LongTensor of true sequence lengths
    max_len = torch.max(lengths).item()
    ids = torch.arange(max_len, device=lengths.device)
    mask = ids < lengths.unsqueeze(1)  # True at valid positions
    return ~mask  # key_padding_mask expects True at PAD positions
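For reference, here is how the resulting mask plugs into nn.MultiheadAttention (a minimal sketch; the sizes and tensor names below are made up for illustration):

import torch
import torch.nn as nn

batch, src_len, embed_dim, nheads = 2, 5, 16, 4
lengths = torch.tensor([5, 3])                      # true length of each sequence
key_padding_mask = get_mask_from_lengths(lengths)   # (batch, src_len), True = pad

mha = nn.MultiheadAttention(embed_dim, nheads)
q = torch.randn(src_len, batch, embed_dim)          # (seq, batch, embed) layout
out, attn = mha(q, q, q, key_padding_mask=key_padding_mask)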
But key_padding_mask only masks the key side: it has shape (batch_size, src_len) and gets broadcast over every query position.
How can I make an attn_mask of shape (batch_size * nheads, lens_0, lens_1) that masks both the query and the key positions?
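One possible approach (a sketch, assuming lens_0/lens_1 are the per-batch query/key lengths, and reusing get_mask_from_lengths from above): build a padding mask for each side, combine them with an outer "or", then repeat the result once per head, since the 3D attn_mask is indexed as batch*nheads with the heads of each batch element stored consecutively.

def get_attn_mask(q_lengths, k_lengths, nheads):
    q_mask = get_mask_from_lengths(q_lengths)         # (batch, lens_0), True = pad
    k_mask = get_mask_from_lengths(k_lengths)         # (batch, lens_1), True = pad
    # A cell is masked if its query OR its key position is padding.
    mask = q_mask.unsqueeze(2) | k_mask.unsqueeze(1)  # (batch, lens_0, lens_1)
    # One copy of the mask per attention head.
    return mask.repeat_interleave(nheads, dim=0)      # (batch*nheads, lens_0, lens_1)

One caveat: rows belonging to padded query positions end up fully masked, and softmax over an all-masked row produces NaN. A common workaround is to mask only the keys (i.e. stick with key_padding_mask) and simply ignore the outputs at padded query positions afterwards.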