I found a way to do this that is about 100x faster and also better in terms of memory usage. The main idea is to compute the indices of the last items directly from the PackedSequence itself and gather them from its flat data, instead of padding it first.
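As a reminder of how a PackedSequence lays out its data (a toy illustration with made-up values, independent of the method itself):

import torch
from torch.nn.utils.rnn import pack_sequence

# two sequences, lengths 3 and 2; packing concatenates time steps:
# step 0 = first items of all sequences, step 1 = second items, ...
seqs = [torch.tensor([10., 11., 12.]), torch.tensor([20., 21.])]
packed = pack_sequence(seqs, enforce_sorted=False)
print(packed.data)         # tensor([10., 20., 11., 21., 12.])
print(packed.batch_sizes)  # tensor([2, 2, 1])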
Assuming you have a PackedSequence object named packed containing sequences of respective lengths lengths, you can extract the last item of each sequence this way:
# sum_batch_sizes[l] = B_0 + ... + B_{l-2}, i.e. the offset of time step l - 1 in packed.data
sum_batch_sizes = torch.cat((
    torch.zeros(2, dtype=torch.int64),
    torch.cumsum(packed.batch_sizes, 0)
))
# lengths reordered to match the sorted (decreasing-length) layout of packed.data
sorted_lengths = lengths[packed.sorted_indices]
# the last item of the i-th sorted sequence is the i-th entry of its last time step
last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
last_seq_items = packed.data[last_seq_idxs]
# restore the original batch order
last_seq_items = last_seq_items[packed.unsorted_indices]
For the most skeptical of you, here is a sketch of the proof. With the notations introduced above:

- packed.data corresponds to X
- packed.batch_sizes corresponds to [B_0, B_1, …, B_{l_0-1}]
- packed.sorted_indices and packed.unsorted_indices correspond to the permutation (and its inverse) that sorts the sequences by decreasing length

Since packed.data concatenates the time steps one after another, time step t starts at offset B_0 + … + B_{t-1}. The last item of the i-th longest sequence, which has length l_i, is the i-th entry of time step l_i - 1, so it sits at index B_0 + … + B_{l_i-2} + i. Prepending two zeros to the cumulative sum makes sum_batch_sizes[l_i] equal to exactly that partial sum, including the degenerate case l_i = 1 where the sum is empty. Hence the code above.
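As a quick sanity check of the index formula, here is a toy example (variable names are mine):

import torch
from torch.nn.utils.rnn import pack_sequence

seqs = [torch.tensor([1., 2., 3.]), torch.tensor([4., 5.]), torch.tensor([6.])]
p = pack_sequence(seqs, enforce_sorted=False)
# p.data -> tensor([1., 4., 6., 2., 5., 3.]), p.batch_sizes -> tensor([3, 2, 1])
s = torch.cat((torch.zeros(2, dtype=torch.int64), torch.cumsum(p.batch_sizes, 0)))
# s -> tensor([0, 0, 3, 5, 6]); s[l] = B_0 + ... + B_{l-2}
sorted_lengths = torch.tensor([3, 2, 1])
idxs = s[sorted_lengths] + torch.arange(3)
# idxs -> tensor([5, 4, 2]); p.data[idxs] -> tensor([3., 5., 6.]), the last items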
Since there is no need to pad the packed sequence to extract the last items with this method, it is faster than the others proposed. It is also more memory-friendly, since you don't have to store all the zeros of the padded sequences.
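To see the memory difference concretely, here is a quick check (a sketch reusing the packed object from above; pad_packed_sequence is imported only for the comparison):

from torch.nn.utils.rnn import pad_packed_sequence

# element counts: the padded tensor stores batch_size * max_length * features
# (mostly zeros), while the packed data stores only sum(lengths) * features
padded, _ = pad_packed_sequence(packed, batch_first=True)
print(padded.numel())
print(packed.data.numel())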
Here is a test and benchmark for 10000 sequences with random lengths between 1 and 99:
Using CPU
Method 1 # TrentBrick Mar 28, 2019
Error: 0.0
10 loops, best of 3: 97.1 ms per loop
Method 2 # BramVanroy Mar 28, 2019
Error: 0.0
10 loops, best of 3: 171 ms per loop
Method 3 # mine
Error: 0.0
1000 loops, best of 3: 820 µs per loop
Using GPU
Method 1
Error: 0.0
10 loops, best of 3: 69 ms per loop
Method 2
Error: 0.0
10 loops, best of 3: 168 ms per loop
Method 3
Error: 0.0
1000 loops, best of 3: 430 µs per loop
Here is the whole script:
import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
lengths = torch.randint(1, 100, (10000,))
sequences = [torch.randn(i, 2) for i in lengths]
ground_truth = torch.stack([seq[-1] for seq in sequences])
packed = pack_sequence(sequences, enforce_sorted=False)
def method1(packed):  # TrentBrick Mar 28, 2019
    output, input_sizes = pad_packed_sequence(packed, batch_first=True)
    last_seq_idxs = torch.LongTensor([x - 1 for x in input_sizes])
    last_seq_items = output[range(output.shape[0]), last_seq_idxs, :]
    return last_seq_items

def method2(packed):  # BramVanroy Mar 28, 2019
    output, input_sizes = pad_packed_sequence(packed, batch_first=True)
    last_seq_items = [output[e, i - 1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
    last_seq_items = torch.cat(last_seq_items, dim=0)
    return last_seq_items

def method3(packed):  # mine
    sum_batch_sizes = torch.cat((
        torch.zeros(2, dtype=torch.int64),
        torch.cumsum(packed.batch_sizes, 0)
    ))
    sorted_lengths = lengths[packed.sorted_indices]
    last_seq_idxs = sum_batch_sizes[sorted_lengths] + torch.arange(lengths.size(0))
    last_seq_items = packed.data[last_seq_idxs]
    last_seq_items = last_seq_items[packed.unsorted_indices]
    return last_seq_items

print('Using CPU')
ground_truth = ground_truth.cpu()
packed = packed.cpu()
lengths = lengths.cpu()
for i, method in enumerate([method1, method2, method3]):
    print('Method', i + 1)
    print('Error:', torch.norm(ground_truth - method(packed)).item())
    %timeit method(packed)
    print()

print('Using GPU')
ground_truth = ground_truth.cuda()
packed = packed.cuda()
# lengths = lengths.cuda()
for i, method in enumerate([method1, method2, method3]):
    print('Method', i + 1)
    print('Error:', torch.norm(ground_truth - method(packed)).item())
    %timeit method(packed)
    print()
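As a side note, if you don't have lengths at hand, the sorted lengths can be recovered from packed.batch_sizes alone, since B_t counts the sequences longer than t. Here is a self-contained sketch along those lines; the function name last_items and the length-recovery step are my own additions, not part of the method above:

import torch

def last_items(packed):
    bs = packed.batch_sizes  # always kept on CPU, even for CUDA-packed data
    n = int(bs[0])           # number of sequences in the batch
    # sorted_lengths[i] = number of time steps that still contain sequence i
    sorted_lengths = (bs.unsqueeze(0) > torch.arange(n).unsqueeze(1)).sum(dim=1)
    sum_batch_sizes = torch.cat((
        torch.zeros(2, dtype=torch.int64),
        torch.cumsum(bs, 0)
    ))
    idxs = sum_batch_sizes[sorted_lengths] + torch.arange(n)
    out = packed.data[idxs]
    if packed.unsorted_indices is not None:  # None when enforce_sorted=True
        out = out[packed.unsorted_indices]
    return out

It should produce the same result as method3 on the benchmark above, without needing the lengths tensor as a global.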