Currently I do this with a list comprehension and pad_sequence, but it seems slow.
Is there a faster way to do this using only PyTorch operations?
import torch
import torch.nn.utils.rnn as rnn
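a = torch.tensor([[1, 4, 7], [2, 5, 8]])
mask = a > 4  # boolean mask, same shape as a
# Select the masked elements row by row (Python-level loop over rows)
selected_no_pad = [a[i][mask[i]] for i in range(a.size(0))]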
# [tensor([7]), tensor([5, 8])]
selected = rnn.pad_sequence(selected_no_pad, batch_first=True, padding_value=-1)
# tensor([[ 7, -1], [ 5, 8]])
Thanks in advance.
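
One possible way to avoid the Python loop is sketched below. It is only a sketch under stated assumptions, not a benchmarked solution: it assumes a 2-D input and a boolean mask of the same shape, and the function name masked_select_padded and its padding_value parameter are hypothetical names made up for this example. The idea is to use a running count of selected elements per row (cumsum) as the destination column, then write all selected values in one indexed assignment.

```python
import torch

def masked_select_padded(a, mask, padding_value=-1):
    # Hypothetical helper: assumes `a` is 2-D and `mask` is a bool tensor of the same shape.
    # Number of selected elements in the fullest row gives the output width.
    max_len = int(mask.sum(dim=1).max())
    # Output starts fully padded.
    out = a.new_full((a.size(0), max_len), padding_value)
    # Destination column of each selected element: its rank among the
    # selected elements of its own row (0-based).
    pos = mask.cumsum(dim=1) - 1
    # Row index of every element, broadcast to the shape of `a`.
    rows = torch.arange(a.size(0), device=a.device).unsqueeze(1).expand_as(a)
    # Write all selected values in a single indexed assignment.
    out[rows[mask], pos[mask]] = a[mask]
    return out

a = torch.tensor([[1, 4, 7], [2, 5, 8]])
mask = a > 4
selected = masked_select_padded(a, mask)
# tensor([[ 7, -1], [ 5,  8]])
```

The order of elements within each row is preserved because cumsum runs left to right. Note that computing max_len with int(...) forces a host synchronization on GPU, so whether this is actually faster than pad_sequence for a given input size would need to be measured.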