Hi,
I have a tensor containing padded sequences, and another one containing the lengths of all those sequences. What is the most efficient way to retrieve a list of those sequences?
For example, assuming I have:
import torch

seq = torch.tensor([
    [3, 2, 5, 0, 0],
    [2, 1, 8, 5, 1],
    [9, 0, 0, 0, 0]
])
lengths = torch.tensor([3, 5, 1])
What is the fastest way to return [tensor([3, 2, 5]), tensor([2, 1, 8, 5, 1]), tensor([9])]?
For now, the least bad solution I have found is:
slices = torch.stack((lengths, lengths.max()-lengths), dim=1).view(-1)
return torch.split(seq.view(-1), slices.tolist())[0::2]
but it is a really ugly (and probably slow) solution.
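For comparison, here is a minimal sketch of two more readable ways to get the same list. The function names `unpad_loop` and `unpad_mask` are hypothetical, both assume the sequences are right-padded and that `lengths` lives on the same device as `seq`, and neither has been benchmarked here, so which one is faster is an open question:

```python
import torch


def unpad_loop(seq, lengths):
    # Hypothetical helper: slice each row down to its true length.
    # tolist() converts the lengths to Python ints in a single transfer.
    return [row[:n] for row, n in zip(seq, lengths.tolist())]


def unpad_mask(seq, lengths):
    # Hypothetical helper: build a boolean mask of valid positions,
    # gather the valid elements into one flat tensor, then split it.
    positions = torch.arange(seq.size(1), device=seq.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, max_len)
    return list(torch.split(seq[mask], lengths.tolist()))


seq = torch.tensor([
    [3, 2, 5, 0, 0],
    [2, 1, 8, 5, 1],
    [9, 0, 0, 0, 0],
])
lengths = torch.tensor([3, 5, 1])

loop_result = unpad_loop(seq, lengths)
mask_result = unpad_mask(seq, lengths)
```

The loop version returns views into `seq`, while the mask version copies the valid elements once and returns views into that copy. Unlike the `lengths.max()` trick, neither relies on the longest sequence filling the full padded width.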
Thanks,
Alain