Get each sequence's last item from packed sequence

I found a way to do this that is about 100x faster than the padding-based approaches and also uses less memory. The main idea is to compute the indices of the last items directly from the PackedSequence itself, instead of padding it first.

Assuming you have a PackedSequence object named packed, holding sequences whose lengths are given by a tensor lengths, you can extract the last item of each sequence like this:

sum_batch_sizes = torch.cat((
    torch.zeros(2, dtype=torch.int64),
    torch.cumsum(packed.batch_sizes, 0)
))
sorted_lengths = lengths[packed.sorted_indices]
last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
last_seq_items = packed.data[last_seq_idxs]
last_seq_items = last_seq_items[packed.unsorted_indices]
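
As a quick sanity check, here is the same recipe on a hypothetical toy batch of three sequences (lengths 3, 1 and 2), where the expected last items are easy to verify by eye:

import torch
from torch.nn.utils.rnn import pack_sequence

# Hypothetical toy batch: three sequences of lengths 3, 1 and 2.
seqs = [torch.tensor([10., 11., 12.]), torch.tensor([20.]), torch.tensor([30., 31.])]
lengths = torch.tensor([3, 1, 2])
packed = pack_sequence(seqs, enforce_sorted=False)

sum_batch_sizes = torch.cat((
    torch.zeros(2, dtype=torch.int64),
    torch.cumsum(packed.batch_sizes, 0)
))
sorted_lengths = lengths[packed.sorted_indices]
last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
last_seq_items = packed.data[last_seq_idxs][packed.unsorted_indices]
print(last_seq_items)  # tensor([12., 20., 31.]) -- the last item of each sequence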

For the most skeptical of you, here is a sketch of the proof.

With the notations introduced above:

  • packed.data corresponds to X, the flattened data stored timestep by timestep
  • packed.batch_sizes corresponds to [B_0, B_1, …, B_{l0-1}], where B_t is the number of sequences still alive at timestep t and l0 is the length of the longest sequence
  • packed.sorted_indices and packed.unsorted_indices correspond to the permutation sorting the sequences by decreasing length, and its inverse

At timestep L_i - 1, the i-th longest sequence (of length L_i) is still present, and the only other sequences present are the i sequences at least as long as it, which come before it in the sorted order. Its last item therefore sits at index B_0 + B_1 + … + B_{L_i - 2} + i in packed.data. sum_batch_sizes is this cumulative sum shifted by two, so that indexing it with L_i directly gives B_0 + … + B_{L_i - 2} (and 0 when L_i = 1); adding torch.arange contributes the + i term, and indexing with unsorted_indices restores the original batch order. Hence the code above.
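
To make the layout concrete, here is a minimal sketch on a hypothetical toy batch (already sorted by decreasing length, so the B_t notation applies directly) that prints how packed.data is organized and checks the index formula:

import torch
from torch.nn.utils.rnn import pack_sequence

# Hypothetical toy batch, already sorted by decreasing length (4, 2, 1).
seqs = [torch.arange(4), torch.arange(10, 12), torch.arange(20, 21)]
packed = pack_sequence(seqs)  # enforce_sorted=True by default

# X is laid out timestep by timestep: B_0 items for t=0, then B_1 for t=1, ...
offset = 0
for t, b in enumerate(packed.batch_sizes.tolist()):
    print(f't={t}: {packed.data[offset:offset + b].tolist()}')
    offset += b
# t=0: [0, 10, 20]
# t=1: [1, 11]
# t=2: [2]
# t=3: [3]

# The last item of the i-th longest sequence (length L_i) sits at
# index B_0 + ... + B_{L_i - 2} + i, which is what sum_batch_sizes encodes.
for i, L in enumerate([4, 2, 1]):
    idx = packed.batch_sizes[:L - 1].sum() + i
    assert packed.data[idx] == seqs[i][-1]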

Since this method never pads the packed sequence, it is faster than the others proposed. It is also more convenient when memory is limited, since you don’t have to store all the padding zeros.
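
To put a rough number on the memory argument, here is a hypothetical comparison of the element counts of the padded tensor versus the packed data, for the same 10000 random sequences used in the benchmark below:

import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

# Padding stores batch * max_len elements; the packed data only sum(lengths).
lengths = torch.randint(1, 100, (10000,))
packed = pack_sequence([torch.randn(i, 2) for i in lengths], enforce_sorted=False)
padded, _ = pad_packed_sequence(packed, batch_first=True)
print(packed.data.numel())  # sum(lengths) * 2
print(padded.numel())       # 10000 * max(lengths) * 2, roughly twice as many here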

Here is a test and benchmark for 10000 sequences with random lengths between 1 and 99:

Using CPU
Method 1 # TrentBrick Mar 28, 2019
Error: 0.0
10 loops, best of 3: 97.1 ms per loop

Method 2 # BramVanroy Mar 28, 2019
Error: 0.0
10 loops, best of 3: 171 ms per loop

Method 3 # mine
Error: 0.0
1000 loops, best of 3: 820 µs per loop

Using GPU
Method 1
Error: 0.0
10 loops, best of 3: 69 ms per loop

Method 2
Error: 0.0
10 loops, best of 3: 168 ms per loop

Method 3
Error: 0.0
1000 loops, best of 3: 430 µs per loop

Here is the whole script:

import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

lengths = torch.randint(1, 100, (10000,))
sequences = [torch.randn(i, 2) for i in lengths]
ground_truth = torch.stack([seq[-1] for seq in sequences])
packed = pack_sequence(sequences, enforce_sorted=False)

def method1(packed): # TrentBrick Mar 28, 2019
    output, input_sizes = pad_packed_sequence(packed, batch_first=True)
    last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])
    last_seq_items = output[range(output.shape[0]), last_seq_idxs, :]
    return last_seq_items

def method2(packed): # BramVanroy Mar 28, 2019
    output, input_sizes = pad_packed_sequence(packed, batch_first=True)
    last_seq_items = [output[e, i-1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
    last_seq_items = torch.cat(last_seq_items, dim=0)
    return last_seq_items

def method3(packed): # mine
    sum_batch_sizes = torch.cat((
        torch.zeros(2, dtype=torch.int64),
        torch.cumsum(packed.batch_sizes, 0)
    ))
    sorted_lengths = lengths[packed.sorted_indices]
    last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
    last_seq_items = packed.data[last_seq_idxs]
    last_seq_items = last_seq_items[packed.unsorted_indices]
    return last_seq_items

print('Using CPU')
ground_truth = ground_truth.cpu()
packed = packed.cpu()
lengths = lengths.cpu()
for i, method in enumerate([method1, method2, method3]):
    print('Method', i+1)
    print('Error:', torch.norm(ground_truth - method(packed)).item())
    %timeit method(packed)
    print()

print('Using GPU')
ground_truth = ground_truth.cuda()
packed = packed.cuda()
# lengths = lengths.cuda()
for i, method in enumerate([method1, method2, method3]):
    print('Method', i+1)
    print('Error:', torch.norm(ground_truth - method(packed)).item())
    %timeit method(packed)
    print()
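
Note that the %timeit lines are IPython magics, so the script as written is meant for a notebook. If you want to run it as a plain Python script, a sketch using the standard timeit module (with an arbitrary iteration count of 100) could look like this:

import timeit

for i, method in enumerate([method1, method2, method3]):
    t = timeit.timeit(lambda: method(packed), number=100)
    print(f'Method {i + 1}: {t / 100 * 1e3:.3f} ms per loop')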