I managed to find a solution based on the one suggested in this post; I'm posting it here for reference.

The idea is quite similar: we retrieve the elements directly from the `PackedSequence` instead of padding it and un-padding the results, which avoids storing tons of useless zeros in memory. Assuming a set of sequences of respective lengths `lengths`, packed in a `PackedSequence` object called `packed`, the sequences can be retrieved with the following function:

```
from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    # Cumulative batch sizes give the offset of each time step in packed.data
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        # Element t of sequence i lives at sum_batch_sizes[t] + its index in the sorted batch
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences
```
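As a quick sanity check, here is a tiny round-trip on toy 1-D sequences using the function above (the sequence values are just for illustration):

```
from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence, pack_sequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences


# Three sequences of different lengths, in arbitrary (unsorted) order
seqs = [torch.arange(3), torch.arange(5), torch.arange(2)]
lengths = torch.tensor([3, 5, 2])
packed = pack_sequence(seqs, enforce_sorted=False)

out = unpack_sequence(packed, lengths)
for original, recovered in zip(seqs, out):
    assert torch.equal(original, recovered)  # round-trip preserves each sequence
```

Note that `lengths` must be given in the original (unsorted) order of the sequences, matching `packed.unsorted_indices`.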

This function induces almost no additional GPU memory usage, and is slightly faster than PyTorch's built-in `unpack_sequence` for big, high-dimensional sequences (which is usually what you deal with when you have memory issues).

Here is a basic script to test which one is faster for you:

```
from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences


if __name__ == "__main__":
    import time

    from torch.nn.utils.rnn import pack_sequence
    from torch.nn.utils.rnn import unpack_sequence as naive_unpack_sequence

    device = torch.device("cuda:0")
    num_sequences = 32
    lengths = torch.randint(1, 4096, (num_sequences,))
    seqs = [torch.randint(10, (length, 2048), device=device) for length in lengths]
    packed = pack_sequence(seqs, enforce_sorted=False)

    # CUDA ops are asynchronous: synchronize around the timed sections,
    # otherwise time.time() measures only the kernel launches
    torch.cuda.synchronize()
    t0 = time.time()
    unpacked_mine = unpack_sequence(packed, lengths)
    torch.cuda.synchronize()
    t1 = time.time()
    print("mine:", t1 - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    unpacked_naive = naive_unpack_sequence(packed)
    torch.cuda.synchronize()
    t1 = time.time()
    print("naive:", t1 - t0)

    # print(unpacked_naive)
    # print(unpacked_mine)
    for a, b in zip(unpacked_naive, unpacked_mine):
        torch.testing.assert_close(a, b)
    print("OK.")
```