How to perform repeat padding for variable-length data?

Well, I wanted to mask attention along the query axis as well.
But with the default attn_mask setup this causes NaNs.
From what I found by googling, it's because the masked rows then contain only -inf values.
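For illustration, a minimal sketch of why this happens (the sizes are made up): if every logit in a row is -inf, which is what happens for a fully padded query position, the softmax has nothing to normalize over and returns NaN.

import torch
import torch.nn.functional as F

# One attention row in which every key position is masked with -inf,
# e.g. a padded query position attending over the keys.
row = torch.full((1, 4), float('-inf'))
print(F.softmax(row, dim=-1))  # tensor([[nan, nan, nan, nan]])

# As soon as at least one entry is finite, the softmax is well defined again.
row[0, 0] = 0.0
print(F.softmax(row, dim=-1))  # tensor([[1., 0., 0., 0.]])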

Now I edited the source code of multi_head_attention_forward like this:

if attn_mask is not None:
    # Reshape to (bsz, num_heads, tgt_len, src_len) so the 4-D mask lines up,
    # fill the masked positions with a tiny value instead of -inf,
    # then reshape back to (bsz * num_heads, tgt_len, src_len).
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_output_weights = attn_output_weights.masked_fill(attn_mask, 1e-9)
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

attn_output_weights = softmax(attn_output_weights, dim=-1)
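The same step can be checked outside the library source; this is only a sketch with made-up sizes, assuming attn_mask is a boolean tensor of shape (bsz, num_heads, tgt_len, src_len) that is True at the positions to be masked:

import torch
from torch.nn.functional import softmax

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
attn_output_weights = torch.randn(bsz * num_heads, tgt_len, src_len)

# Hypothetical mask: pretend the last two key positions are padding everywhere.
attn_mask = torch.zeros(bsz, num_heads, tgt_len, src_len, dtype=torch.bool)
attn_mask[..., -2:] = True

# Same reshape / fill / reshape as in the edited forward above.
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(attn_mask, 1e-9)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = softmax(attn_output_weights, dim=-1)

print(attn_output_weights.shape)               # torch.Size([8, 5, 7])
print(torch.isnan(attn_output_weights).any())  # tensor(False)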

And I made the masks like this:

def get_mask_from_lengths_3d(batch_size, lengths_query, lengths_key, nheads):
    # Start with a (bsz, src_len, tgt_len) mask of zeros (nothing masked yet).
    mask = torch.zeros(batch_size, lengths_key.max(),
                       lengths_query.max()).cuda()

    # Mark key (source) positions past each sequence's length as padding.
    max_len = torch.max(lengths_key).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask[ids > lengths_key.unsqueeze(1) - 1] = 1

    # Switch to (bsz, tgt_len, src_len), the layout of the attention weights.
    mask = mask.transpose(1, 2)

    # Mark query (target) positions past each sequence's length as padding.
    max_len = torch.max(lengths_query).item()
    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
    mask[ids > lengths_query.unsqueeze(1) - 1] = 1

    # Expand to one mask per head: (bsz, nheads, tgt_len, src_len), True = masked.
    return mask.unsqueeze(1).repeat(1, nheads, 1, 1).bool()
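A quick sanity check of the padding mask, assuming a CUDA device is available and torch is imported (the lengths are made up):

lengths_query = torch.tensor([3, 5]).cuda()  # e.g. decoder lengths per batch element
lengths_key = torch.tensor([7, 4]).cuda()    # e.g. encoder lengths per batch element

mask = get_mask_from_lengths_3d(2, lengths_query, lengths_key, nheads=4)
print(mask.shape)  # torch.Size([2, 4, 5, 7]) -> (bsz, nheads, tgt_len, src_len)
print(mask[0, 0])  # True where the query or the key position is padding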


def generate_square_subsequent_mask_3d(batch_size, lengths_query, nheads):
    # Upper-triangular mask (1 above the diagonal) blocks attention to future positions.
    sz = lengths_query.max().item()
    mask = torch.triu(torch.ones(sz, sz), 1).cuda()
    mask = mask.unsqueeze(0).repeat(batch_size, 1, 1)

    # Additionally mask out query positions that are padding.
    ids = torch.arange(0, sz, out=torch.cuda.LongTensor(sz))
    mask[ids > lengths_query.unsqueeze(1) - 1] = 1

    # Expand to one mask per head: (bsz, nheads, tgt_len, tgt_len), True = masked.
    return mask.unsqueeze(1).repeat(1, nheads, 1, 1).bool()
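For decoder self-attention, the two masks can be combined with a logical OR, so a position is blocked if it is either in the future or in the padding. A sketch with made-up lengths, assuming the modified forward takes the 4-D mask directly (otherwise it can be flattened to (bsz * nheads, tgt_len, src_len)):

lengths = torch.tensor([3, 5]).cuda()
bsz, nheads = 2, 4

pad_mask = get_mask_from_lengths_3d(bsz, lengths, lengths, nheads)      # padding
causal_mask = generate_square_subsequent_mask_3d(bsz, lengths, nheads)  # future positions
self_attn_mask = pad_mask | causal_mask                                 # (2, 4, 5, 5)

# If a 3-D mask is needed instead, collapse the batch and head dimensions.
self_attn_mask_3d = self_attn_mask.flatten(0, 1)  # (bsz * nheads, tgt_len, src_len)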

The alignment for one layer seems to be right.

With a mask value of float('-inf') it became NaN immediately.