Get each sequence's last item from a packed sequence

[x-post from Stack Overflow]

I am trying to put a packed and padded sequence through a GRU and retrieve the output of the last item of each sequence. Of course I don't mean the item at index -1, but the actual last, non-padded item. We know the lengths of the sequences in advance, so it should be as easy as extracting, for each sequence, the item at index length-1. Concretely, with the data below (lengths [6, 4, 3, 1]), I want output[0, 5], output[1, 3], output[2, 2] and output[3, 0], stacked into a 4x12 tensor.

I tried the following:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Data
input = torch.Tensor([[[0., 0., 0.],
                       [1., 0., 1.],
                       [1., 1., 0.],
                       [1., 0., 1.],
                       [1., 0., 1.],
                       [1., 1., 0.]],
                      
                      [[1., 1., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[0., 0., 0.],
                       [1., 0., 0.],
                       [1., 1., 1.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[1., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]]])

lengths = [6, 4, 3, 1]
p = pack_padded_sequence(input, lengths, batch_first=True)

# Forward
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, gru_h = gru(p)

# Unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])

last_seq_items = torch.index_select(output, 1, last_seq_idxs) 

print(last_seq_items.size())
# torch.Size([4, 4, 12])

But the shape is not what I expect. I had expected to get 4x12, i.e. the last item of each individual sequence x hidden size.

I could loop through the whole thing and build a new tensor containing the items I need, but I was hoping for a built-in approach that takes advantage of some smart math. I fear that manually looping and building will result in very poor performance.

No idea if this is the most efficient method, but I think it gets the job done.

Change your line:

last_seq_items = torch.index_select(output, 1, last_seq_idxs) 

to this:

last_seq_items = output[range(output.shape[0]), last_seq_idxs, :]

By using range(), we give the indexing a different position along dimension 1 for each element of dimension 0 (the batch), instead of selecting the same set of timesteps for every sequence.
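To see why this works, here is a toy illustration of that paired advanced indexing (the values are made up for the example):

import torch

out = torch.arange(24).reshape(2, 3, 4)  # (batch, time, hidden)
# The two index sequences are paired up element-wise:
# this picks out[0, 2] and out[1, 0]
rows = out[range(out.shape[0]), [2, 0], :]
print(rows.shape)  # torch.Size([2, 4])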


@TrentBrick Thank you for that. That makes sense, I think.

Personally, I came up with this:

output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# One per sequence, with its last actual node extracted, and unsqueezed
last_seq = [output[e, i-1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
# Merge them together
last_seq = torch.cat(last_seq, dim=0) 
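(Equivalently, torch.stack would fold the unsqueeze and cat into a single call; a minor variant of the same loop:)

last_seq = torch.stack([output[e, i-1, :] for e, i in enumerate(input_sizes)], dim=0)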

@BramVanroy I think mine might be faster because it is for-loop-free! And glad I could help :slight_smile:


Someone suggested using torch's arange. Not sure whether this would be faster than Python's native range, but I am posting it for completeness' sake.
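Presumably it would look something like this (a sketch, reusing output and last_seq_idxs from the snippets above):

batch_idxs = torch.arange(output.shape[0])
last_seq_items = output[batch_idxs, last_seq_idxs, :]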


That is probably right!

I found a way to do this which is about 100x faster and is also better in terms of memory usage. The main idea is to directly extract the indices from the PackedSequence itself instead of padding it.

Assuming you have a PackedSequence object named packed, containing sequences whose respective lengths are stored in a LongTensor lengths, you can extract the last item of each sequence this way:

# Prepend two zeros so that sum_batch_sizes[l] is the offset
# of the block of timestep l-1 in packed.data
sum_batch_sizes = torch.cat((
    torch.zeros(2, dtype=torch.int64),
    torch.cumsum(packed.batch_sizes, 0)
))
sorted_lengths = lengths[packed.sorted_indices]
# Flat index of the last item of each sequence, in sorted order
last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
last_seq_items = packed.data[last_seq_idxs]
# Restore the original batch order
last_seq_items = last_seq_items[packed.unsorted_indices]

For the most skeptical of you, here is a sketch of the proof.

With the following notations:

  • X corresponds to packed.data
  • [B_0, B_1, …, B_{l_0 - 1}] corresponds to packed.batch_sizes, so B_t is the number of sequences still alive at timestep t and l_0 is the length of the longest sequence
  • packed.sorted_indices and packed.unsorted_indices correspond to the permutation sorting the sequences by decreasing length

In the packed layout, the block of timestep t starts at offset B_0 + … + B_{t-1} in X, and within that block the i-th sorted sequence sits at position i (sorting by decreasing length guarantees that sequence i is present in every timestep t < l_i). The last item of the i-th sorted sequence is therefore X at flat index (B_0 + … + B_{l_i - 2}) + i. Prepending two zeros to the cumulative sum of the batch sizes makes that offset exactly sum_batch_sizes[l_i] (and correctly gives 0 when l_i = 1). Hence the code above.
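As a quick sanity check of the index formula, here is a toy example (values made up for illustration):

import torch
from torch.nn.utils.rnn import pack_sequence

seqs = [torch.tensor([10, 11, 12]), torch.tensor([20, 21])]
packed = pack_sequence(seqs, enforce_sorted=False)
# packed.data == [10, 20, 11, 21, 12], packed.batch_sizes == [2, 2, 1]
# sum_batch_sizes == [0, 0, 2, 4, 5] (cumsum with two zeros prepended)
# seq 0 (length 3): index sum_batch_sizes[3] + 0 = 4 -> packed.data[4] == 12
# seq 1 (length 2): index sum_batch_sizes[2] + 1 = 3 -> packed.data[3] == 21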

Since there is no need to pad the packed sequence to extract the last items, this method is faster than the others proposed. It is also more convenient if memory is limited, since you don't have to store all the zeros from the padded sequences.

Here is a test and benchmark for 10000 sequences with random lengths between 1 and 99:

Using CPU
Method 1 # TrentBrick Mar 28, 2019
Error: 0.0
10 loops, best of 3: 97.1 ms per loop

Method 2 # BramVanroy Mar 28, 2019
Error: 0.0
10 loops, best of 3: 171 ms per loop

Method 3 # mine
Error: 0.0
1000 loops, best of 3: 820 µs per loop

Using GPU
Method 1
Error: 0.0
10 loops, best of 3: 69 ms per loop

Method 2
Error: 0.0
10 loops, best of 3: 168 ms per loop

Method 3
Error: 0.0
1000 loops, best of 3: 430 µs per loop

Here is the whole script (the %timeit lines assume it is run in IPython/Jupyter):

import torch

from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

lengths = torch.randint(1, 100, (10000,))

sequences = [torch.randn(i, 2) for i in lengths]

ground_truth = torch.stack([seq[-1] for seq in sequences])

packed = pack_sequence(sequences, enforce_sorted=False)

def method1(packed): # TrentBrick Mar 28, 2019
    output, input_sizes = pad_packed_sequence(packed, batch_first=True)
    last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])
    last_seq_items = output[range(output.shape[0]), last_seq_idxs, :]
    return last_seq_items

def method2(packed): # BramVanroy Mar 28, 2019
    output, input_sizes = pad_packed_sequence(packed, batch_first=True)
    last_seq_items = [output[e, i-1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
    last_seq_items = torch.cat(last_seq_items, dim=0)
    return last_seq_items

def method3(packed): # mine
    sum_batch_sizes = torch.cat((
        torch.zeros(2, dtype=torch.int64),
        torch.cumsum(packed.batch_sizes, 0)
    ))
    sorted_lengths = lengths[packed.sorted_indices]  # note: uses the global lengths tensor
    last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
    last_seq_items = packed.data[last_seq_idxs]
    last_seq_items = last_seq_items[packed.unsorted_indices]
    return last_seq_items

print('Using CPU')
ground_truth = ground_truth.cpu()
packed = packed.cpu()
lengths = lengths.cpu()
for i, method in enumerate([method1, method2, method3]):
    print('Method', i+1)
    print('Error:', torch.norm(ground_truth - method(packed)).item())
    %timeit method(packed)
    print()

print('Using GPU')
ground_truth = ground_truth.cuda()
packed = packed.cuda()
# lengths = lengths.cuda()
for i, method in enumerate([method1, method2, method3]):
    print('Method', i+1)
    print('Error:', torch.norm(ground_truth - method(packed)).item())
    %timeit method(packed)
    print()

@aRI0U Your method is really helpful. I made a script based on your description, with an additional function to fetch the lengths from the PackedSequence itself:

from typing import Tuple

import torch
from torch import Tensor
from torch import jit
from torch.nn.utils.rnn import pack_sequence, PackedSequence
from torch.nn.utils.rnn import pad_packed_sequence


@jit.script
def sorted_lengths(pack: PackedSequence) -> Tuple[Tensor, Tensor]:
    # The i-th sorted sequence has length equal to the number of
    # timesteps whose batch size is at least i + 1
    indices = torch.arange(
        pack.batch_sizes[0],
        dtype=pack.batch_sizes.dtype,
        device=pack.batch_sizes.device,
    )
    lengths = ((indices + 1)[:, None] <= pack.batch_sizes[None, :]).long().sum(dim=1)
    return lengths, indices


@jit.script
def sorted_first_indices(pack: PackedSequence) -> Tensor:
    return torch.arange(
        pack.batch_sizes[0],
        dtype=pack.batch_sizes.dtype,
        device=pack.batch_sizes.device,
    )


@jit.script
def sorted_last_indices(pack: PackedSequence) -> Tensor:
    # Same trick as above: prepend two zeros so that
    # cum_batch_sizes[length] is the offset of timestep length - 1
    lengths, indices = sorted_lengths(pack)
    cum_batch_sizes = torch.cat([
        pack.batch_sizes.new_zeros((2,)),
        torch.cumsum(pack.batch_sizes, dim=0),
    ], dim=0)
    return cum_batch_sizes[lengths] + indices


@jit.script
def first_items(pack: PackedSequence, unsort: bool) -> Tensor:
    # Timestep 0 of every sequence; unsorted_indices restores batch order
    if unsort and pack.unsorted_indices is not None:
        return pack.data[pack.unsorted_indices]
    else:
        return pack.data[:pack.batch_sizes[0]]


@jit.script
def last_items(pack: PackedSequence, unsort: bool) -> Tensor:
    # Last item of every sequence, optionally in the original batch order
    indices = sorted_last_indices(pack=pack)
    if unsort and pack.unsorted_indices is not None:
        indices = indices[pack.unsorted_indices]
    return pack.data[indices]


if __name__ == '__main__':
    x = pack_sequence([
        torch.randperm(5) + 1,
        torch.randperm(2) + 1,
        torch.randperm(3) + 1,
    ], enforce_sorted=False)
    z, _ = pad_packed_sequence(x, batch_first=True)
    print(z)

    print(first_items(pack=x, unsort=True))
    print(last_items(pack=x, unsort=True))
    print(first_items(pack=x, unsort=False))
    print(last_items(pack=x, unsort=False))

# tensor([[1, 5, 3, 2, 4],
#         [2, 1, 0, 0, 0],
#         [1, 2, 3, 0, 0]])
# tensor([1, 2, 1])
# tensor([4, 1, 3])
# tensor([1, 1, 2])
# tensor([4, 3, 1])