`unpack_sequence` without padding

Hi,
I have operations that I need to do in parallel on the GPU for batches of sequences.
The issue is that the sequences have very different lengths and I have limited GPU memory, so I cannot afford to pad them to a common length just to process them per batch.
I found that PackedSequence could help, since the data in these objects is stored compactly without any padding. However, at the end I need to unpack the object back into a list of tensors, and unpack_sequence internally uses pad_packed_sequence, while avoiding padding is exactly why I turned to PackedSequence in the first place…
Does anyone know if there is a smart way to unpack sequences without naive padding?
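
For concreteness, a toy illustration of the problem (made-up sizes, not my real data): with very uneven lengths, a padded batch has to allocate batch_size * max_length elements, while the packed data only needs sum(lengths) elements:

import torch
from torch.nn.utils.rnn import pad_sequence, pack_sequence

lengths = [4096, 10, 7, 3]
seqs = [torch.zeros(n, 2048) for n in lengths]

padded = pad_sequence(seqs)                         # shape (4096, 4, 2048)
packed = pack_sequence(seqs, enforce_sorted=False)  # flat data, no padding

print(padded.numel())       # 33_554_432 elements, mostly zeros
print(packed.data.numel())  # 8_429_568 elements: (4096 + 10 + 7 + 3) * 2048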

Thanks a lot!

I managed to find a solution based on the one suggested in this post; I am posting it here for reference.

The idea here is quite similar: we retrieve the elements directly from the PackedSequence instead of padding it and un-padding the results, which avoids filling memory with tons of useless zeros. Assuming a set of sequences of respective lengths `lengths` packed into a PackedSequence object called `packed`, the sequences can be retrieved with the following function:

def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    # Offset of each time step in the flat packed.data tensor: prepend a zero
    # to batch_sizes and take the cumulative sum.
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        # Element t of this sequence lives at sum_batch_sizes[t] + its position
        # in the length-sorted batch (given by unsorted_indices).
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences
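
To see why this indexing works, here is a toy example using the function above (values just for illustration, not my actual data). PackedSequence stores the data time-major with the longest sequence first, so element t of a sequence sits at offset sum_batch_sizes[t] plus the sequence's position in the length-sorted batch, which is exactly what unsorted_indices gives:

import torch
from torch.nn.utils.rnn import pack_sequence

a = torch.tensor([1., 2., 3.])  # length 3
b = torch.tensor([4.])          # length 1
c = torch.tensor([5., 6.])      # length 2

packed = pack_sequence([a, b, c], enforce_sorted=False)
print(packed.data)              # tensor([1., 5., 4., 2., 6., 3.]) -> time step 0, then 1, then 2
print(packed.batch_sizes)       # tensor([3, 2, 1])
print(packed.unsorted_indices)  # tensor([0, 2, 1]) -> positions of a, b, c in the sorted batch

# sum_batch_sizes is [0, 3, 5, 6]; sequence b (length 1) sits at sorted position 2,
# so its only element is packed.data[0 + 2] = 4.
print(unpack_sequence(packed, torch.tensor([3, 1, 2])))
# [tensor([1., 2., 3.]), tensor([4.]), tensor([5., 6.])]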

This function induces almost no additional GPU memory usage, and for big/high-dimensional sequences (which is usually what you are dealing with when you have memory issues) it is also slightly faster than the padding-based unpack_sequence.

Here is a basic script to test which one is faster for you:

from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences


if __name__ == "__main__":
    import time

    from torch.nn.utils.rnn import pack_sequence
    from torch.nn.utils.rnn import unpack_sequence as naive_unpack_sequence

    device = torch.device("cuda:0")

    num_sequences = 32
    lengths = torch.randint(1, 4096, (num_sequences,))
    seqs = [torch.randint(10, (length, 2048), device=device) for length in lengths]

    packed = pack_sequence(seqs, enforce_sorted=False)

    # Synchronize before and after each call so that we time the GPU work
    # itself, not just the asynchronous kernel launches.
    torch.cuda.synchronize(device)
    t0 = time.time()
    unpacked_mine = unpack_sequence(packed, lengths)
    torch.cuda.synchronize(device)
    t1 = time.time()
    print("mine:", t1 - t0)

    torch.cuda.synchronize(device)
    t0 = time.time()
    unpacked_naive = naive_unpack_sequence(packed)
    torch.cuda.synchronize(device)
    t1 = time.time()
    print("naive:", t1 - t0)

    # print(unpacked_naive)
    # print(unpacked_mine)

    for a, b in zip(unpacked_naive, unpacked_mine):
        torch.testing.assert_close(a, b)
    print("OK.")