Using Sparsemax along with a key_padding mask for a batch_size > 1

Hi,

I have recently been toying with Sparsemax instead of Softmax to get sparsity in the attention matrix.

What I have noticed is that with a batch_size of 1, I get sparsity in the attention matrix (for a given query, some of the keys are set to 0). No key padding mask is used in this case.
However, when I use a batch_size > 1 and provide a key_padding mask, the attention matrix has zeros only at the masked positions (for a given query, only the keys at the masked positions are set to 0). I am not getting sparsity among the keys that actually matter (the unmasked positions).
Based on how sparsemax works (sparsemax · PyPI), this is probably expected, but it might then prevent using a batch_size > 1.
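
To illustrate what I mean, here is a minimal sketch of the masking + sparsemax interaction outside of my model (assuming Sparsemax from the sparsemax PyPI package linked above; the scores are made up):

import torch
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)

# One query over six keys; the last two keys are "padding" and get the same
# large negative fill value that masked_fill uses in the code below.
scores = torch.tensor([[2.0, 1.8, 1.7, 1.6, -9e15, -9e15]])
print(sparsemax(scores))
# The masked keys are always zero, but whether any of the unmasked keys are
# zeroed out depends on how spread out their scores are.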

Has anyone come across this issue?

Thanks!

Below is my MAB class:


import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from sparsemax import Sparsemax


class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K, key_mask=None, attn_mask=None):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the projections into heads and stack the heads along the batch dimension.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Scaled dot-product scores: (num_heads * batch, n_queries, n_keys).
        A = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)

        if key_mask is not None:
            # The mask is expected to broadcast against A, with 1 marking padded key positions.
            attn_mask = key_mask
            A = A.masked_fill(attn_mask == 1, -9e15)

        # Sparsemax over the key dimension instead of softmax.
        sparsemax = Sparsemax(dim=-1)
        A = sparsemax(A)

        # Attention output with a residual connection, then re-merge the heads.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, "ln0", None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, "ln1", None) is None else self.ln1(O)

        return (O, A)
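
For reference, here is a minimal call that matches the shapes the code above expects (the dimensions are made up, and the key_mask is expanded to the num_heads * batch leading dimension so it lines up with A):

import torch

mab = MAB(dim_Q=64, dim_K=64, dim_V=64, num_heads=4, ln=True)

Q = torch.randn(2, 5, 64)            # batch of 2, 5 queries each
K = torch.randn(2, 7, 64)            # 7 keys/values each, last 2 are padding
key_mask = torch.zeros(4 * 2, 5, 7)  # (num_heads * batch, n_queries, n_keys)
key_mask[:, :, 5:] = 1               # 1 marks the padded key positions

O, A = mab(Q, K, key_mask=key_mask)
print(A.shape)                       # torch.Size([8, 5, 7])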