Build a 2D mask from a sequence of lengths

Hello!
It's easiest to explain my request with an example.
Let there be a sequence of integers s = [2, 1, 3]. I want to create the following mask from it: mask = [[1, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1]]. Row i of the 2-D mask contains s[i] consecutive ones, starting right where the ones of row i - 1 end. Thus, every row of the mask has length sum(s).
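To pin down the intended semantics, here is a straightforward loop-based reference for a single sequence. It is only a sketch for checking vectorized versions against; the name `mask_from_lengths` is hypothetical, and it assumes `s` is a 1-D tensor of non-negative integers.

```python
import torch

def mask_from_lengths(s):
    # s: 1-D integer tensor of run lengths, e.g. [2, 1, 3]
    total = int(s.sum().item())
    mask = torch.zeros(len(s), total, dtype=torch.long)
    start = 0
    for i, n in enumerate(s.tolist()):
        # row i gets n consecutive ones, beginning where row i - 1's ones ended
        mask[i, start:start + n] = 1
        start += n
    return mask

print(mask_from_lengths(torch.tensor([2, 1, 3])))
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1]])
```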
Moreover, I want to do this efficiently for batched input (a batch of such sequences). Of course, sum(s) differs between sequences in a batch, so the mask rows would have different lengths; I want to pad the shorter ones with trailing zeros.
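One possibly simpler vectorized approach for the batched case: row i covers the half-open interval [start_i, end_i), where end = cumsum(s) and start = end - s, so the mask is a single broadcast comparison against column positions. This is a sketch under two assumptions: sequences with fewer entries are padded with zero lengths (which yields all-zero rows), and the hypothetical `batched_mask` returns integer 0/1 values. The width is the largest sum(s) in the batch, so shorter masks get trailing zeros automatically.

```python
import torch

def batched_mask(s):
    # s: (B, L) non-negative integer lengths, zero-padded on the right
    ends = torch.cumsum(s, dim=1)                # (B, L) exclusive end of each run
    starts = ends - s                            # (B, L) inclusive start of each run
    width = int(ends[:, -1].max().item())        # largest sum(s) in the batch
    pos = torch.arange(width, device=s.device)   # (width,) column positions
    # broadcast to (B, L, width): True where start <= pos < end
    mask = (pos >= starts.unsqueeze(-1)) & (pos < ends.unsqueeze(-1))
    return mask.long()

s = torch.tensor([[2, 1, 3],
                  [1, 2, 0]])
print(batched_mask(s))
```

Padding columns need no special handling here: any position at or beyond a sequence's own sum(s) fails `pos < end` for every row of that sequence.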

I have some rather convoluted vectorized code for this, but I'd like something simpler. In case anyone is interested, my code is below.

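import torch

def sequence_mask(s):
    # s: (N,) lengths -> (N, max_len) bool mask, True at positions < s[j]
    max_len = int(s.max().item())
    x = torch.arange(max_len, dtype=s.dtype, device=s.device)
    return x.unsqueeze(0) < s.unsqueeze(1)

def generate_mask_2d(s):
    # s: (B, L) integer lengths; shorter sequences padded with zero lengths
    B, L = s.shape
    # (B, L): cumulative sums = end offset of each run of ones
    s = torch.cumsum(s, dim=1)
    # (B * L)
    s = s.reshape(B * L)

    # (B * L, max_len), where max_len = largest sum(s) in the batch
    mask = sequence_mask(s)
    max_len = mask.size(1)
    # (B, L, max_len)
    mask = mask.view(B, L, max_len)
    # torch.diff on a bool tensor XORs adjacent rows, so row i keeps only
    # positions in [cumsum[i - 1], cumsum[i]); prepend an all-False row
    prepend = torch.zeros(B, 1, max_len, dtype=mask.dtype, device=mask.device)
    mask = torch.diff(mask, dim=1, prepend=prepend)
    return mask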
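Another way to view the same structure: each column j belongs to exactly one row, namely the row whose run of ones covers it. For a single sequence this gives a short alternative built on `torch.repeat_interleave` and `torch.nn.functional.one_hot`. This is a sketch only; `mask_via_one_hot` is a hypothetical name, it assumes a 1-D int64 tensor of non-negative lengths, and it does not handle batching or padding by itself.

```python
import torch
import torch.nn.functional as F

def mask_via_one_hot(s):
    # s: 1-D int64 lengths, e.g. [2, 1, 3]
    # owner[j] = index of the row whose run covers column j -> [0, 0, 1, 2, 2, 2]
    owner = torch.repeat_interleave(torch.arange(len(s)), s)
    # one_hot gives (sum(s), len(s)); transpose to (len(s), sum(s))
    return F.one_hot(owner, num_classes=len(s)).T

print(mask_via_one_hot(torch.tensor([2, 1, 3])))
```

Rows with s[i] = 0 simply never appear in `owner`, so they come out as all-zero rows.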