How to `pack_sequence` and `PackedSequence`

Hi! I can’t find an up-to-date example that uses pack_sequence and its output PackedSequence in the context of an RNN-like network. That is, I have a model that processes a sequence time step by time step.

As far as I understand, the older combination of pad_sequence, pack_padded_sequence, and pad_packed_sequence can be replaced by the newer pack_sequence (at least for the common use case).
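To check that understanding, here is a tiny sketch I put together myself (toy sequences, not from the docs) comparing the one-call route against the older pad-then-pack route; both seem to produce the same packed data:

```python
import torch
from torch.nn.utils.rnn import pack_sequence, pad_sequence, pack_padded_sequence

seqs = [torch.arange(3.0), torch.arange(2.0)]

# Newer one-call route
packed_a = pack_sequence(seqs, enforce_sorted=False)

# Older two-step route: pad first, then pack with explicit lengths
lengths = torch.tensor([len(s) for s in seqs])
padded = pad_sequence(seqs)  # shape (max_len, batch)
packed_b = pack_padded_sequence(padded, lengths, enforce_sorted=False)

print(torch.equal(packed_a.data, packed_b.data))
print(torch.equal(packed_a.batch_sizes, packed_b.batch_sizes))
```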

I think I got my head around how to use pack_sequence, and I wonder if you could give some feedback on whether I got it right. A minimal working example is:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence

BATCH_SIZE = 5
N_SAMPLES = 7


class Dataset(torch.utils.data.Dataset):
    def __init__(self, n=25):
        super().__init__()
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Each item is a random-length sequence (5 to 9 steps) with target y = 2 * x
        l = torch.randint(low=5, high=10, size=(1,))
        x = torch.arange(l.item(), dtype=torch.float)
        y = x * 2.0
        return x, y


def collate_pack(batch):
    # Pack inputs and targets separately; per-sample lengths match,
    # so enforce_sorted=False sorts both the same way
    xx, yy = zip(*batch)

    return (
        pack_sequence(xx, enforce_sorted=False),
        pack_sequence(yy, enforce_sorted=False),
    )


dataset = Dataset(n=N_SAMPLES)

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SIZE, collate_fn=collate_pack
)

model = nn.Linear(in_features=1, out_features=1)

for i_batch, batch in enumerate(dataloader):
    print("Batch=", i_batch)
    print("-" * 20)

    x_pkd, y_pkd = batch

    idx_start = 0
    # batch_sizes[t] = number of sequences still active at time step t
    for t, batch_size in enumerate(x_pkd.batch_sizes):
        print("Processing time t=", t)

        x, y = (
            x_pkd.data[idx_start : idx_start + batch_size],
            y_pkd.data[idx_start : idx_start + batch_size],
        )
        x = x.unsqueeze(1)  # (batch,) -> (batch, 1) so Linear sees in_features = 1
        y = y.unsqueeze(1)  # (batch,) -> (batch, 1) to match the prediction's shape
        print("\tInput:", x.shape, y.shape)

        y_hat = model(x)
        print("\tOutput:", y_hat.shape)

        idx_start += batch_size
        print()

print("Done!")

The idea is that:

  • The dataset contains sequences of (x, y) of different length
  • Because the sequence lengths differ, the dataloader needs a custom collate function. This is where I use pack_sequence
  • When I iterate over the dataloader, each batch is a pair of PackedSequences. As I loop through the time steps, the effective batch size shrinks (i.e., later time steps only contain values from the longer sequences)
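To make the last point concrete, here is a tiny sketch (lengths chosen by me for illustration) of how batch_sizes encodes that shrinking effective batch:

```python
import torch
from torch.nn.utils.rnn import pack_sequence

# Three sequences of lengths 4, 2, 2 (already sorted longest-first,
# so the default enforce_sorted=True is fine here)
seqs = [torch.arange(4.0), torch.arange(2.0), torch.arange(2.0)]
packed = pack_sequence(seqs)

# At t=0 and t=1 all three sequences are active; at t=2 and t=3 only the first
print(packed.batch_sizes)  # tensor([3, 3, 1, 1])
print(packed.data)         # flat data, grouped time step by time step
```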

It seems to work, but I don’t much like the pattern

x_pkd.data[idx_start : idx_start + batch_size]

Do you know if PackedSequence already has a method to get the right values directly?
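The closest workaround I found myself is torch.split, which at least hides the index arithmetic, sketched here on a toy packed sequence (not sure it is the intended way):

```python
import torch
from torch.nn.utils.rnn import pack_sequence

packed = pack_sequence([torch.arange(3.0), torch.arange(2.0)])

# torch.split accepts a list of chunk sizes, so batch_sizes
# gives exactly one chunk of data per time step
steps = torch.split(packed.data, packed.batch_sizes.tolist())
for t, x in enumerate(steps):
    print(t, x)
```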