How to do row-wise masking as a batch operation?

My question is about how to do row-wise masking in a batched manner.
I wrote the non-batch case below.

import torch

n_dim0, n_dim1 = 4, 100
threshold = 0.5

storage = torch.rand((n_dim0, n_dim1))
mask = (threshold < storage)

print('storage:', storage.shape)  # [4, 100]
print('mask:   ', mask.shape)     # [4, 100]

mask_lens = torch.sum(mask, dim=1)    # [4]  number of kept elements per row
max_mask_lens = torch.max(mask_lens)  # scalar (0-dim tensor)

masked = torch.zeros((n_dim0, max_mask_lens))    # zero-padded output, [4, max_mask_lens]
masked[0, 0:mask_lens[0]] = storage[0, mask[0]]  #
masked[1, 0:mask_lens[1]] = storage[1, mask[1]]  # How to turn these lines into
masked[2, 0:mask_lens[2]] = storage[2, mask[2]]  # a batch operation?
masked[3, 0:mask_lens[3]] = storage[3, mask[3]]  #

print('masked: ', masked.shape)  # [4, X]  (X is around 50)

I hope someone has a good idea.
Thank you.
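
One direction that might work (an untested sketch, reusing the tensors defined above; it assumes a PyTorch version with the `stable` flag of `torch.sort`, i.e. 1.9 or later): stably sort each row of the mask in descending order so the kept columns move to the front, then gather.

# Sketch: batch the per-row masking with a stable descending sort.
# The stable sort keeps the original column order among the kept (True) positions.
sorted_mask, order = torch.sort(mask.long(), dim=1, descending=True, stable=True)

# Move each row's kept values to the front, in their original order.
gathered = torch.gather(storage, 1, order)

# Zero out the positions whose mask is False, then trim to the widest row.
masked = torch.where(sorted_mask.bool(), gathered, torch.zeros_like(gathered))
masked = masked[:, :int(max_mask_lens)]

print('masked: ', masked.shape)  # [4, X]  (X is around 50)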

One tricky idea came to mind, but it is neither elegant nor efficient.

import torch

n_dim0, n_dim1 = 4, 100
threshold = 0.5

storage = torch.rand((n_dim0, n_dim1))
mask = (threshold < storage)

print('storage:', storage.shape)  # [4, 100]
print('mask:   ', mask.shape)     # [4, 100]

mask_lens = torch.sum(mask, dim=1)    # [4]  number of kept elements per row
max_mask_lens = torch.max(mask_lens)  # scalar (0-dim tensor)

# Create (row, col) indices of all kept elements
mask_idxs = torch.nonzero(mask)  # [num_true, 2]  (2 == storage.ndim)

# Append padding indices so every row ends up with `max_mask_lens` entries.
# Note: the pad column index is 0, so padded slots hold storage[row_idx, 0]
# instead of zeros (unlike the non-batch version above).
pad_sizes = max_mask_lens - mask_lens
pads = []
for row_idx, pad_size in enumerate(pad_sizes):
    pad = torch.zeros((pad_size, storage.ndim), dtype=torch.long)
    pad[:, 0] = row_idx
    pads.append(pad)
mask_idxs = torch.cat([mask_idxs] + pads, dim=0)

# Sort `mask_idxs` by row index; the sort must be stable so that, within each
# row, the genuine indices keep their order and stay ahead of the padding entries.
mask_idxs_argsort = torch.argsort(mask_idxs[:, 0], stable=True)
mask_idxs = mask_idxs[mask_idxs_argsort]

# Masking
masked = storage[mask_idxs[:, 0], mask_idxs[:, 1]]

# Restore original dimension
masked = masked.reshape(n_dim0, -1)

print('masked: ', masked.shape)  # [4, X]  (X is around 50)
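
If this index-padding route is kept, the Python loop over `pad_sizes` could probably be replaced with `torch.repeat_interleave` (again an untested sketch, reusing the variables from the snippet above). The padded slots still pick up `storage[row, 0]` rather than zeros, exactly as in the loop version.

# Sketch: build the padding index pairs without a Python loop.
# Each pad entry points at column 0 of its row, same as in the loop above.
pad_rows = torch.repeat_interleave(torch.arange(n_dim0), pad_sizes)  # [total_pads]
pads = torch.stack([pad_rows, torch.zeros_like(pad_rows)], dim=1)    # [total_pads, 2]
mask_idxs = torch.cat([torch.nonzero(mask), pads], dim=0)

# The stable sort, gather, and reshape steps stay the same as above.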