DataLoader for variable-length data

I’ve been working on implementing a seq2seq model and tried to use torch.utils.data.DataLoader to batch data, following the Data Loading and Processing Tutorial. It seems DataLoader cannot handle variable-length data. Is there another way to batch sequences of different lengths?

15 Likes

You could create a transform that trims / pads each sample to a specific length and then use the pack_padded_sequence function.
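
A minimal sketch of such a transform (the class name TrimPad and the max_len parameter are just placeholders, not an existing PyTorch API):

import torch

class TrimPad:
    """Trim or zero-pad a time-major tensor to a fixed length max_len."""

    def __init__(self, max_len):
        self.max_len = max_len

    def __call__(self, seq):
        if seq.size(0) >= self.max_len:
            return seq[:self.max_len]
        pad_shape = (self.max_len - seq.size(0),) + tuple(seq.shape[1:])
        return torch.cat([seq, seq.new_zeros(pad_shape)], dim=0)

You can keep the original lengths around as well and pass them to pack_padded_sequence after padding.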

3 Likes

You need to customize your own dataloader.

What you basically need to do is pad your variable-length inputs and torch.stack() them together into a single tensor. This tensor will then be used as an input to your model.
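
For example, a minimal sketch of that pad-and-stack step (pad_and_stack is just an illustrative helper, assuming each sample is a tensor whose first dimension is time):

import torch

def pad_and_stack(seqs, pad_value=0.0):
    # seqs: list of tensors of shape (T_i, ...) with varying T_i
    max_len = max(s.size(0) for s in seqs)
    padded = []
    for s in seqs:
        pad_shape = (max_len - s.size(0),) + tuple(s.shape[1:])
        padded.append(torch.cat([s, s.new_full(pad_shape, pad_value)], dim=0))
    return torch.stack(padded, dim=0)  # shape (batch, max_len, ...)

torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True) does essentially the same thing.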

I think it’s worth mentioning that using pack_padded_sequence isn’t absolutely necessary. pack_padded_sequence is mainly designed to work with the cuDNN LSTM/GRU/RNN implementations, which are optimized to run very fast.

But if you have your own method that prevents you from using the standard LSTM/GRU/RNN, then, as mentioned here:

The easiest way to make a custom RNN compatible with variable-length sequences is to do what this repo does (masking): GitHub - jihunchoi/recurrent-batch-normalization-pytorch: PyTorch implementation of recurrent batch normalization
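
As a rough illustration of the masking idea (a sketch of the general technique, not the repo’s actual code), you can build a boolean mask from the sequence lengths and zero out the outputs at padded time steps:

import torch

def length_mask(lengths, max_len):
    # lengths: LongTensor of shape (batch,) -> bool mask of shape (batch, max_len)
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]

# example: mask a (batch, max_len, hidden) output of a custom RNN
lengths = torch.tensor([3, 1, 2])
outputs = torch.randn(3, 4, 5)
mask = length_mask(lengths, max_len=4)                     # (3, 4)
masked = outputs * mask.unsqueeze(-1).to(outputs.dtype)    # padded steps become zero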

2 Likes

Thanks. Do you mean getting a batch of data and padding it manually? That’s exactly what I’m doing. I’m just wondering if there’s a proper ‘PyTorch’ way to do this.

I meant to create your own Dataset class and then apply a transform that pads to a given length; there is an example of a custom dataset class below. The idea would be to add a transform that pads the tensors, so that upon every call of __getitem__() the tensors are padded and thus the batch consists entirely of padded tensors. You could also have __getitem__() return a third value, the original length of the tensor, so you can do masking.
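
A minimal sketch of that idea (the attribute names, max_len, and the padding scheme are placeholder assumptions, not the linked example):

import torch
from torch.utils.data import Dataset

class PaddedSeqDataset(Dataset):
    def __init__(self, sequences, labels, max_len):
        self.sequences = sequences      # list of tensors of shape (T_i, ...)
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx][:self.max_len]
        length = seq.size(0)
        pad_shape = (self.max_len - length,) + tuple(seq.shape[1:])
        padded = torch.cat([seq, seq.new_zeros(pad_shape)], dim=0)
        # return the original length as a third value so you can build a mask later
        return padded, self.labels[idx], length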

1 Like

I was wondering if there is a more efficient way of padding sequences. The easiest option is to just pad all sequences to the maximum possible length; currently I implemented my own Dataset object and use a transform that pads all sequences to the same length. But is there a way to do that per batch rather than globally for the whole dataset (i.e., pad the batch when the DataLoader samples it)? Sounds like I need to create my own DataLoader?

Edit:
I found a possible solution at: http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html.
Specifically, you can implement your own version of BatchSampler to pad according to the longest sequence in the batch. I will post my implementation when done.

3 Likes

I think you want to use the collate_fn argument of the DataLoader class.

I did one there with packed sequences. I don’t know if this is the fastest way, but it would accomplish what you want to do. You could also use any of the pre-built samplers that you want.

3 Likes

Thanks David, collate_fn was a good direction 🙂. I wrote some simple code that maybe someone here can reuse. I wanted something that pads a generic dim, and I don’t use an RNN of any kind, so PackedSequence was overkill for me. It’s simple, but it works for me.

import torch


def pad_tensor(vec, pad, dim):
    """
    args:
        vec - tensor to pad
        pad - the size to pad to
        dim - dimension to pad

    return:
        a new tensor padded to 'pad' in dimension 'dim'
    """
    pad_size = list(vec.shape)
    pad_size[dim] = pad - vec.size(dim)
    return torch.cat([vec, torch.zeros(*pad_size)], dim=dim)


class PadCollate:
    """
    a variant of collate_fn that pads according to the longest sequence in
    a batch of sequences
    """

    def __init__(self, dim=0):
        """
        args:
            dim - the dimension to be padded (dimension of time in sequences)
        """
        self.dim = dim

    def pad_collate(self, batch):
        """
        args:
            batch - list of (tensor, label)

        return:
            xs - a tensor of all examples in 'batch' after padding
            ys - a LongTensor of all labels in batch
        """
        # find longest sequence
        max_len = max(map(lambda x: x[0].shape[self.dim], batch))
        # pad according to max_len
        batch = [(pad_tensor(x, pad=max_len, dim=self.dim), y) for x, y in batch]
        # stack all
        xs = torch.stack([x for x, _ in batch], dim=0)
        ys = torch.LongTensor([y for _, y in batch])
        return xs, ys

    def __call__(self, batch):
        return self.pad_collate(batch)


to be used with the data loader:
train_loader = DataLoader(ds, ..., collate_fn=PadCollate(dim=0))

31 Likes

Felix, I think your code only pads correctly if dim=0. This is because the pad vector in the pad_tensor function has *vec.size()[1:] hardcoded into it. I think you need to create a vector that is pad - vec.size(dim) in the dim dimension, not always in the zeroth dimension. However, I could be wrong. I adapted the code to work with Python 3 and added the ability to pad with different values, so I may have screwed something up in the process.

David, you are correct, I updated the pad function to work with any dim, thanks.

1 Like

If you are going to pack your padded sequences later, you can also immediately sort the batches from longest sequence to shortest:

import numpy as np
import torch


def sort_batch(batch, targets, lengths):
    """
    Sort a minibatch by sequence length, with the longest sequences first, and
    return the sorted batch, targets, and sequence lengths.
    This way the output can be used by pack_padded_sequence(...).
    """
    seq_lengths, perm_idx = lengths.sort(0, descending=True)
    seq_tensor = batch[perm_idx]
    target_tensor = targets[perm_idx]
    return seq_tensor, target_tensor, seq_lengths

def pad_and_sort_batch(DataLoaderBatch):
    """
    DataLoaderBatch should be a list of (sequence, target, length) tuples...
    Returns a padded tensor of sequences sorted from longest to shortest.
    """
    batch_size = len(DataLoaderBatch)
    batch_split = list(zip(*DataLoaderBatch))

    seqs, targs, lengths = batch_split[0], batch_split[1], batch_split[2]
    max_length = max(lengths)

    padded_seqs = np.zeros((batch_size, max_length))
    for i, l in enumerate(lengths):
        padded_seqs[i, 0:l] = seqs[i][0:l]

    return sort_batch(torch.tensor(padded_seqs), torch.tensor(targs).view(-1,1), torch.tensor(lengths))

This assumes that your Dataset spits out something like

def __getitem__(self, idx):
    return self.sequences[idx], torch.tensor(self.targets[idx]), self.sequence_lengths[idx]

And then you pass the pad_and_sort_batch collator to the DataLoader as:

train_gen = Data.DataLoader(train_data, batch_size=128, shuffle=True, collate_fn=pad_and_sort_batch)

5 Likes

For future readers: if, like me, you were looking for the above (which is great) but also want to batch your sequences by length to minimize the padding needed, I wrote a batch sampler for this: Tensorflow-esque bucket by sequence length
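
The rough idea (a simplified sketch under my own naming, not the exact sampler from that thread): sort the indices by sequence length, slice them into batches, and shuffle the batch order so each batch contains sequences of similar length.

import random
from torch.utils.data import Sampler

class BucketBatchSampler(Sampler):
    """Yields lists of indices whose sequences have similar lengths."""

    def __init__(self, lengths, batch_size):
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batches = [order[i:i + batch_size]
                        for i in range(0, len(order), batch_size)]

    def __iter__(self):
        random.shuffle(self.batches)   # shuffle batch order every epoch
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

# combined with a padding collate_fn, e.g. the PadCollate above:
# loader = DataLoader(ds, batch_sampler=BucketBatchSampler(lengths, 32),
#                     collate_fn=PadCollate(dim=0))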

4 Likes

To answer the original question, you can pass a (simple and short) custom collate function to the data loader that uses pack_sequence.

pack_sequence does not require the sequences to be padded, and with enforce_sorted=False it does not require them to be sorted by length either, so it is simpler to use.

Here is the code that does this (based on this answer to a similar question: How to create a dataloader with variable-size input):

from torch.nn.utils.rnn import pack_sequence
from torch.utils.data import DataLoader

def my_collate(batch):
    # batch contains a list of tuples of structure (sequence, target)
    data = [item[0] for item in batch]
    data = pack_sequence(data, enforce_sorted=False)
    targets = [item[1] for item in batch]
    return [data, targets]

# ...
# later in your code, when you define your DataLoader, use the custom collate function
loader = DataLoader(dataset,
                      batch_size,
                      shuffle,
                      collate_fn=my_collate, # use custom collate function here
                      pin_memory=True)
5 Likes

This is how I solved it:

import torch

# assumes `device` is already defined elsewhere, e.g. device = torch.device('cuda')

def collate_fn_padd(batch):
    '''
    Pads a batch of variable-length sequences.

    note: it converts things to tensors manually here since the ToTensor transform
    assumes it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## pad
    batch = [ torch.Tensor(t).to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask
    mask = (batch != 0).to(device)
    return batch, lengths, mask

Related posts:

- bucketing
- Stack Overflow version
- crossposted: https://www.quora.com/unanswered/How-does-Pytorch-Dataloader-handle-variable-size-data

1 Like

Here’s a collator I use; it works for tensors of any dimension:

from typing import List

import torch
from torch.utils.data import DataLoader


class ZeroPadCollator:

    @staticmethod
    def collate_tensors(batch: List[torch.Tensor]) -> torch.Tensor:
        dims = batch[0].dim()
        max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
        size = (len(batch),) + tuple(max_size)
        canvas = batch[0].new_zeros(size=size)
        for i, b in enumerate(batch):
            sub_tensor = canvas[i]
            for d in range(dims):
                sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
            sub_tensor.add_(b)
        return canvas

    def collate(self, batch) -> List[torch.Tensor]:
        dims = len(batch[0])
        return [self.collate_tensors([b[i] for b in batch]) for i in range(dims)]

Then I simply use:

    zero_pad = ZeroPadCollator()
    loader = DataLoader(train, args.batch_size, collate_fn=zero_pad.collate)
1 Like

For the others who might have the same issue with RNNs and variable-length sequences, here is my solution, assuming your dataset’s __getitem__ method returns a pair (seq, target):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def collate_fn_pad(list_pairs_seq_target):
    seqs = [seq for seq, target in list_pairs_seq_target]
    targets = [target for seq, target in list_pairs_seq_target]
    seqs_padded_batched = pad_sequence(seqs)   # pads shorter sequences with zeros at the end; shape (max_len, batch, ...)
    targets_batched = torch.stack(targets)
    assert seqs_padded_batched.shape[1] == len(targets_batched)
    return seqs_padded_batched, targets_batched
    
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn_pad)

for seq, labels in dataloader:
    y_pred = rnn(seq)
1 Like

Out of curiosity: you use this at test time only, right? During training you may want truly stochastic batches.

1 Like