I managed to find a solution based on the one suggested in this post; I'm posting it here for reference.

The idea is quite similar: we retrieve the elements directly from the `PackedSequence` instead of padding it and un-padding the results, which avoids storing tons of useless zeros in memory. Assuming a set of sequences of respective lengths `lengths`, packed in a `PackedSequence` object called `packed`, the sequences can be retrieved with the following function:

```
from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    # Cumulative batch sizes give the offset of each time step in packed.data
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        # Element t of sequence i lives at sum_batch_sizes[t] + its index in the sorted batch
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences
```
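As a quick sanity check, here is a tiny round-trip on toy 1-D sequences using the function above (the sequence values are just for illustration):

```
from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence, pack_sequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences


# Three sequences of different lengths, in arbitrary (unsorted) order
seqs = [torch.arange(3), torch.arange(5), torch.arange(2)]
lengths = torch.tensor([3, 5, 2])
packed = pack_sequence(seqs, enforce_sorted=False)

out = unpack_sequence(packed, lengths)
for original, recovered in zip(seqs, out):
    assert torch.equal(original, recovered)  # round-trip preserves each sequence
```

Note that `lengths` must be given in the original (unsorted) order of the sequences, matching `packed.unsorted_indices`.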

This function induces almost no additional GPU memory usage, and is slightly faster than PyTorch's built-in `unpack_sequence` for big, high-dimensional sequences (which is usually what you deal with when you have memory issues).

Here is a basic script to test which one is faster for you:

```
from typing import List

import torch
from torch.nn.functional import pad
from torch.nn.utils.rnn import PackedSequence


def unpack_sequence(packed: PackedSequence, lengths: torch.Tensor) -> List[torch.Tensor]:
    sum_batch_sizes = pad(packed.batch_sizes.to(packed.data.device), (1, 0)).cumsum(dim=0)
    sequences = []
    for seq_idx, seq_length in zip(packed.unsorted_indices, lengths):
        indices = sum_batch_sizes[:seq_length] + seq_idx
        sequences.append(packed.data[indices])
    return sequences


if __name__ == "__main__":
    import time

    from torch.nn.utils.rnn import pack_sequence
    from torch.nn.utils.rnn import unpack_sequence as naive_unpack_sequence

    device = torch.device("cuda:0")
    num_sequences = 32
    lengths = torch.randint(1, 4096, (num_sequences,))
    seqs = [torch.randint(10, (length, 2048), device=device) for length in lengths]
    packed = pack_sequence(seqs, enforce_sorted=False)

    # CUDA ops are asynchronous: synchronize around the timed sections,
    # otherwise time.time() measures only the kernel launches
    torch.cuda.synchronize()
    t0 = time.time()
    unpacked_mine = unpack_sequence(packed, lengths)
    torch.cuda.synchronize()
    t1 = time.time()
    print("mine:", t1 - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    unpacked_naive = naive_unpack_sequence(packed)
    torch.cuda.synchronize()
    t1 = time.time()
    print("naive:", t1 - t0)

    # print(unpacked_naive)
    # print(unpacked_mine)
    for a, b in zip(unpacked_naive, unpacked_mine):
        torch.testing.assert_close(a, b)
    print("OK.")
```