Select a variable number of values from a batched sequence and make it a padded sequence

Currently, I do this with a list comprehension and `pad_sequence`, but it seems slow.
Is there a faster way to do this using PyTorch tensor operations?

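```python
import torch
import torch.nn.utils.rnn as rnn

a = torch.tensor([[1, 4, 7], [2, 5, 8]])
mask = a > 4  # boolean mask of the values to keep, same shape as a

# Select the masked values row by row; each row may keep a different number of values
selected_no_pad = [a[i][mask[i]] for i in range(a.size(0))]
# [tensor([7]), tensor([5, 8])]

# Pad the variable-length rows into a (batch, max_len) tensor, filling with -1
selected = rnn.pad_sequence(selected_no_pad, batch_first=True, padding_value=-1)
# tensor([[ 7, -1],
#         [ 5,  8]])
```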
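One possible vectorized alternative is sketched below. It assumes `a` and `mask` are 2-D tensors of the same shape `(batch, seq_len)`, and `select_and_pad` is a hypothetical helper name. The idea is to compute the target column of each kept value with a `cumsum` over the mask, then write all kept values into a pre-filled output with a single indexed assignment, so there is no Python loop over the batch. The output width still depends on the data, so `int(lengths.max())` (like boolean indexing itself) will likely force a host sync on GPU.

```python
import torch


def select_and_pad(a: torch.Tensor, mask: torch.Tensor, padding_value: int = -1) -> torch.Tensor:
    """Vectorized sketch of list comprehension + pad_sequence (batch_first=True).

    Assumes `a` and `mask` are 2-D tensors with the same shape (batch, seq_len).
    """
    batch = a.size(0)
    # Number of kept values in each row; the padded width is the largest count.
    lengths = mask.sum(dim=1)
    max_len = int(lengths.max()) if batch > 0 else 0
    out = a.new_full((batch, max_len), padding_value)

    # Target column of each kept value: 0, 1, 2, ... counting only True entries in its row.
    cols = mask.long().cumsum(dim=1) - 1
    # Row index of every element, broadcast to the shape of `a`.
    rows = torch.arange(batch, device=a.device).unsqueeze(1).expand_as(a)

    # Write all kept values into their (row, col) slots in one indexed assignment.
    out[rows[mask], cols[mask]] = a[mask]
    return out


a = torch.tensor([[1, 4, 7], [2, 5, 8]])
print(select_and_pad(a, a > 4))
# tensor([[ 7, -1],
#         [ 5,  8]])
```

Because boolean indexing returns elements in row-major order, the values in each output row keep their original left-to-right order, matching what `pad_sequence` produces from the list comprehension.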

Thanks in advance.
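Another vectorized sketch uses a stable sort instead of `cumsum`. This assumes a PyTorch version whose `torch.sort` accepts `stable=True`, and `select_and_pad_sort` is a hypothetical helper name. Sorting `~mask` in ascending order moves the kept positions to the front of each row without changing their relative order. `gather` then reorders the values, and everything past each row's length is overwritten with the padding value.

```python
import torch


def select_and_pad_sort(a: torch.Tensor, mask: torch.Tensor, padding_value: int = -1) -> torch.Tensor:
    """Sort-based sketch; assumes 2-D `a` and `mask` of the same shape."""
    lengths = mask.sum(dim=1)
    max_len = int(lengths.max()) if a.size(0) > 0 else 0
    # Kept positions map to 0, dropped ones to 1; a stable ascending sort puts
    # kept positions first while preserving their original order.
    order = torch.sort((~mask).int(), dim=1, stable=True).indices
    # Reorder each row and keep only the first max_len columns.
    packed = a.gather(1, order)[:, :max_len]
    # Columns at or beyond a row's length are padding.
    pad = torch.arange(max_len, device=a.device).unsqueeze(0) >= lengths.unsqueeze(1)
    return packed.masked_fill(pad, padding_value)


a = torch.tensor([[1, 4, 7], [2, 5, 8]])
print(select_and_pad_sort(a, a > 4))
# tensor([[ 7, -1],
#         [ 5,  8]])
```

The sort costs O(n log n) per row, while `cumsum` is linear, so the `cumsum` variant may be cheaper for long rows. Timing both on the actual tensor sizes would settle it.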
