Ability to sort dataset batches to reduce excessive padding?

The recent updates to torchtext make it a nightmare to understand, and none of the legacy code is supported by newer versions of PyTorch and PyTorch Lightning.

As a result, I'd like to switch to the PyTorch DataLoader and Dataset. One thing I find massively lacking there is that I have to provide a max_length value to truncate or pad the data items within each sampled batch.
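For what it's worth, a fixed max_length isn't strictly required with a plain DataLoader: a custom collate_fn can pad each batch only up to its own longest item. The sketch below assumes each dataset item is a hypothetical (token_ids, label) pair with token_ids a 1-D list of ints; pad_collate is an illustrative name, while pad_sequence is the real helper in torch.nn.utils.rnn.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch, pad_value=0):
    """Pad to the longest sequence in *this* batch, not a global max_length.

    Assumption: every item is a (token_ids, label) pair, token_ids being a 1-D list of ints.
    """
    seqs = [torch.as_tensor(tokens, dtype=torch.long) for tokens, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    # Keep the true lengths so the model (or pack_padded_sequence) can ignore padding.
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(seqs, batch_first=True, padding_value=pad_value)
    return padded, lengths, labels
```

Usage would be something like `DataLoader(dataset, batch_size=32, collate_fn=pad_collate)`. On its own this removes truncation, but a random batch can still mix very short and very long items, so padding is only reduced once similar lengths are grouped together.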

Is there a way to select each batch from a list of data items sorted by length in descending order, so that the lengths within each batch vary only minimally and, as a result, we can get away with minimal padding and no truncation?
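One possible way to do this with the stock DataLoader is a batch sampler: sort all indices by length (descending), cut the sorted list into consecutive batches, then shuffle the order of the batches so training doesn't always go longest-first. This is a minimal sketch, not an existing PyTorch class; SortedBatchSampler is a hypothetical name, and it assumes you can precompute a list of item lengths.

```python
import random
from typing import Iterator, List, Sequence


class SortedBatchSampler:
    """Hypothetical batch sampler: batches of indices whose items have similar lengths."""

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 shuffle: bool = True, seed: int = 0):
        self.lengths = lengths          # assumed: lengths[i] is the length of dataset[i]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        # Sort every index by its item's length, longest first.
        order = sorted(range(len(self.lengths)),
                       key=lambda i: self.lengths[i], reverse=True)
        # Consecutive chunks of the sorted order contain similar lengths.
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.shuffle:
            # Shuffle whole batches, not items, so each batch stays homogeneous.
            self.rng.shuffle(batches)
        return iter(batches)

    def __len__(self) -> int:
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size
```

It would be passed as `DataLoader(dataset, batch_sampler=SortedBatchSampler(lengths, 32), collate_fn=...)`. Note that `batch_sampler` can't be combined with `batch_size`, `shuffle`, `sampler` or `drop_last`.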

One way I can think of is to rewrite the dataset in a particular sorted order. Or should we choose a batch without sorting but sort within the batch? Can we do it in an online, streaming sort of way?
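Sorting *within* an already-chosen batch doesn't reduce padding, because each batch is padded to its longest member regardless of order; it mainly matters for `pack_padded_sequence` with `enforce_sorted=True`. For the streaming variant, a common trick is to read a "pool" of several batches' worth of items, sort only that pool, split it into batches and shuffle them. A rough sketch follows; pooled_batches and pool_factor are hypothetical names, and the stream is assumed to be any iterable of items with a length key.

```python
import random
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def pooled_batches(stream: Iterable[T], batch_size: int, key: Callable[[T], int],
                   pool_factor: int = 50, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical helper: approximate length-sorted batching over a stream."""
    rng = random.Random(seed)
    it = iter(stream)
    while True:
        # Only batch_size * pool_factor items are held in memory at once.
        pool = list(islice(it, batch_size * pool_factor))
        if not pool:
            return
        pool.sort(key=key, reverse=True)   # sort locally, longest first
        batches = [pool[i:i + batch_size] for i in range(0, len(pool), batch_size)]
        rng.shuffle(batches)               # avoid a strict long-to-short order
        yield from batches
```

A larger pool_factor gives more homogeneous batches at the cost of memory and randomness. Inside an IterableDataset you could yield these pre-formed batches and build the DataLoader with `batch_size=None` so they aren't batched a second time.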

This topic by @vdw might be helpful. :slight_smile:


Thanks! I found some good ideas here - Tensorflow-esque bucket by sequence length
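The bucketing idea can also be sketched in plain Python, loosely modeled on TensorFlow's bucket-by-sequence-length approach (this is not its actual implementation): each item goes into a bucket chosen by length boundaries, and a batch is emitted as soon as a bucket fills up. bucket_by_length, boundaries and the flush-at-end behavior are all assumptions for illustration.

```python
from bisect import bisect_right
from typing import Callable, Dict, Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def bucket_by_length(stream: Iterable[T], key: Callable[[T], int],
                     boundaries: Sequence[int], batch_size: int) -> Iterator[List[T]]:
    """Hypothetical helper: bucket i holds lengths in [boundaries[i-1], boundaries[i])."""
    buckets: Dict[int, List[T]] = {}
    for item in stream:
        # Number of boundaries <= length picks the bucket index.
        b = bisect_right(boundaries, key(item))
        bucket = buckets.setdefault(b, [])
        bucket.append(item)
        if len(bucket) == batch_size:
            yield bucket
            buckets[b] = []            # start a fresh list for this bucket
    # Emit any partially filled buckets once the stream ends.
    for b in sorted(buckets):
        if buckets[b]:
            yield buckets[b]
```

With boundaries `[4, 8]`, lengths below 4, from 4 to 7, and 8 or more never share a batch, so padding within a batch is bounded by the bucket width.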