Hi,
I have recently been toying with Sparsemax instead of Softmax to get sparsity in the attention matrix.
What I have noticed is that with a batch_size of 1, I get sparsity in the attention matrix (for a given query, some of the keys are set to 0). No key padding mask is used in this case.
However, when I use a batch_size > 1 and provide a key padding mask, the attention matrix has zeros only in the masked positions (for a given query, only the keys in the masked positions are set to 0). I am not getting sparsity among the keys that actually matter (the unmasked positions).
I think, based on how sparsemax works (sparsemax · PyPI), this is probably expected, but then it might prevent me from using a batch_size > 1.
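To illustrate what I mean, here is a tiny standalone sketch (the scores are made up; the masking convention with a large negative value follows what I do in my forward() below):

import torch
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=-1)

# One query over five keys; the last two keys are "padding" and have been
# filled with a large negative value before normalization.
scores = torch.tensor([[2.0, 1.5, 0.1, -9e15, -9e15]])
print(sparsemax(scores))  # padded keys get weight 0; the unmasked keys may or may not come out sparse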
Has anyone come across this issue?
Thanks!
Below is my MAB class:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from sparsemax import Sparsemax


class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K, key_mask=None, attn_mask=None):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the embedding dimension into heads and stack the heads
        # along the batch dimension.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # Scaled dot-product scores; padded keys are filled with a large
        # negative value before normalizing with sparsemax.
        A = Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V)
        if key_mask is not None:
            attn_mask = key_mask
            A = A.masked_fill(attn_mask == 1, -9e15)
        sparsemax = Sparsemax(dim=-1)
        A = sparsemax(A)

        # Residual connection, undo the head split, then feed-forward block.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, "ln0", None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, "ln1", None) is None else self.ln1(O)
        return (O, A)
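And roughly how I call it (the shapes and dims here are just a toy example of my setup, with mask value 1 = padded position, using the MAB class above):

# Toy usage: batch of 2 sets with 4 elements each; the last element of the
# second set is padding.
batch, n, d, heads = 2, 4, 16, 4
X = torch.rand(batch, n, d)

key_mask = torch.zeros(batch, 1, n)
key_mask[1, :, -1] = 1  # broadcast over queries

mab = MAB(dim_Q=d, dim_K=d, dim_V=d, num_heads=heads)

# The heads are stacked along the batch dimension inside forward(), so the
# mask is tiled once per head before being passed in.
O, A = mab(X, X, key_mask=key_mask.repeat(heads, 1, 1))

print(O.shape)  # (batch, n, d)
print(A.shape)  # (batch * heads, n, n); this is where I inspect the sparsity per query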