Well, I wanted to mask the attention along the query axis as well, but with the default attn_mask setup this causes NaNs. From what I found on Google, it happens because some rows along that axis contain nothing but -inf, so the softmax over them produces NaN.
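For reference, a minimal standalone sketch (not from my model) of why this happens: if every logit in a row is -inf, the softmax normalizer is zero and the whole row comes out as NaN:

import torch
import torch.nn.functional as F

# one query position attending to 4 keys, all of them masked with -inf
logits = torch.full((1, 4), float('-inf'))
print(F.softmax(logits, dim=-1))  # tensor([[nan, nan, nan, nan]])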
So I edited the source code of multi_head_attention_forward like this:
if attn_mask is not None:
    # view as (bsz, num_heads, tgt_len, src_len) so the 4-D mask can be applied,
    # and fill masked positions with a small constant instead of float('-inf')
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_output_weights = attn_output_weights.masked_fill(attn_mask, 1e-9)
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = softmax(attn_output_weights, dim=-1)
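Here is a quick standalone check of that patched branch, with made-up shapes and a last query row that is fully masked (the way a padded query position would be):

import torch
from torch.nn.functional import softmax

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_output_weights = torch.randn(bsz * num_heads, tgt_len, src_len)
attn_mask = torch.zeros(bsz, num_heads, tgt_len, src_len, dtype=torch.bool)
attn_mask[:, :, -1, :] = True  # fully masked query row

attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(attn_mask, 1e-9)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = softmax(attn_output_weights, dim=-1)
print(torch.isnan(attn_output_weights).any())  # tensor(False)

The fully masked rows come out uniform instead of NaN.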
And I made the masks like this:
def get_mask_from_lengths_3d(batch_size, lengths_query, lengths_key, nheads):
    # (batch, max_key_len, max_query_len) buffer, transposed to (batch, query, key) below
    mask = torch.zeros(batch_size, lengths_key.max(),
                       lengths_query.max()).cuda()
    # mask out padded key positions
    max_len = torch.max(lengths_key).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask[ids > lengths_key.unsqueeze(1) - 1] = 1
    mask = mask.transpose(1, 2)
    # mask out padded query positions
    max_len = torch.max(lengths_query).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask[ids > lengths_query.unsqueeze(1) - 1] = 1
    # (batch, nheads, max_query_len, max_key_len), True = masked
    return mask.unsqueeze(1).repeat(1, nheads, 1, 1).bool()
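Assuming the function above is in scope and a GPU is available (it hard-codes .cuda()), a hypothetical call with made-up lengths looks like this:

import torch

batch_size, nheads = 2, 4
lengths_query = torch.tensor([3, 5]).cuda()  # e.g. decoder lengths per sample
lengths_key = torch.tensor([6, 4]).cuda()    # e.g. encoder lengths per sample

mask = get_mask_from_lengths_3d(batch_size, lengths_query, lengths_key, nheads)
print(mask.shape)  # torch.Size([2, 4, 5, 6]) -> (batch, nheads, max_query_len, max_key_len)

which matches the (bsz, num_heads, tgt_len, src_len) view used in the patched forward.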
def generate_square_subsequent_mask_3d(batch_size, lengths_query, nheads):
    sz = lengths_query.max().item()
    # causal mask: ones above the diagonal (future positions), converted to bool at the end
    mask = torch.triu(torch.ones(sz, sz), 1).cuda(
    ).unsqueeze(0).repeat(batch_size, 1, 1)
    # also mask the rows that correspond to padded query positions
    ids = torch.arange(0, sz, out=torch.cuda.LongTensor(sz))
    mask[ids > lengths_query.unsqueeze(1) - 1] = 1
    # (batch, nheads, tgt_len, tgt_len), True = masked
    return mask.unsqueeze(1).repeat(1, nheads, 1, 1).bool()
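And a hypothetical call for the decoder self-attention mask, again with made-up lengths:

import torch

batch_size, nheads = 2, 4
lengths_query = torch.tensor([3, 5]).cuda()

self_attn_mask = generate_square_subsequent_mask_3d(batch_size, lengths_query, nheads)
print(self_attn_mask.shape)  # torch.Size([2, 4, 5, 5])
print(self_attn_mask[0, 0].int())
# upper triangle is 1 (future positions), and rows 3-4 are all 1 (padded queries of the first sample)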
The alignment for one layer seems to be right. With the mask value float('-inf') it became NaN immediately.