Can I group batches by length?

Hi everyone.

I am working on an ASR project where I use a model from HuggingFace (wav2vec2). My goal for now is to move the training loop to plain PyTorch, so I am trying to recreate everything that HuggingFace’s Trainer() class provides.

One of these utilities is the ability to group batches by length and combine this with dynamic padding (via a data collator). To be honest, however, I am not sure how to even begin. Would I need to create a custom DataLoader class and modify it so that every batch it yields contains samples whose lengths are as close as possible?

The inputs are 1-D arrays that represent the raw waveform of a .wav file. One idea I had was to somehow sort the data from longest to shortest (or the opposite) and each time extract batch_size samples from it. This way, the first batch would consist of the longest samples, the second batch of the next longest, and so on.
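The sort-then-chunk idea can be sketched in a few lines of plain Python. This is a minimal sketch, assuming the length of every waveform is already known; `lengths` and `length_sorted_batches` are hypothetical names, and the result is only lists of indices, which would still have to be loaded and padded into tensors.

```python
def length_sorted_batches(lengths, batch_size, descending=True):
    """Group dataset indices into batches of similar length.

    lengths[i] is the number of audio samples in waveform i. Indices are
    sorted by length and cut into consecutive chunks of batch_size; the
    last batch may be smaller.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=descending)
    return [order[start:start + batch_size] for start in range(0, len(order), batch_size)]


# Hypothetical waveform lengths (number of audio samples per file).
lengths = [80, 20, 35, 10, 75, 82, 90, 30]
batches = length_sorted_batches(lengths, batch_size=4)
print(batches)                                 # [[6, 5, 0, 4], [2, 7, 1, 3]]
print([[lengths[i] for i in b] for b in batches])  # [[90, 82, 80, 75], [35, 30, 20, 10]]
```

Sorted this way, every epoch produces exactly the same batches, which is the drawback discussed further down in the thread.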

Nevertheless, I am not sure how to approach this implementation. I also searched online but did not manage to find something already implemented. Any advice will be greatly appreciated.

Thanks in advance.

EDIT: Now that I think about it, could I possibly do it somehow in the data collator function?
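A data collator on its own cannot do the grouping: it only receives the samples that the sampler has already put into one batch. What it can handle is the dynamic padding half. A minimal sketch, assuming each dataset item is a 1-D float tensor; `pad_collate` is a hypothetical name, and returning a 0/1 padding mask is an assumption about what the model needs.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Pad a list of 1-D waveform tensors to the longest one in this batch.

    Returns the padded batch of shape (batch, max_len) and a mask that is
    1 for real audio samples and 0 for padding.
    """
    lengths = torch.tensor([w.shape[0] for w in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    # mask[i, t] is True exactly where t < lengths[i]
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]
    return padded, mask.long()


batch = [torch.randn(80), torch.randn(75), torch.randn(90)]
padded, mask = pad_collate(batch)
print(padded.shape, mask.sum(dim=1))  # torch.Size([3, 90]) tensor([80, 75, 90])
```

It would be plugged in as `DataLoader(dataset, batch_size=4, collate_fn=pad_collate)`. Padding is only cheap if the samples in a batch already have similar lengths, so this still needs a length-aware sampler or dataset.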

This is not in the standard PyTorch tooling, but the last time I needed this, I moved the batching into the dataset and then just squeezed out the additional singleton batch dimension added by the default collate. This seemed reasonably straightforward and performant to me.

Best regards


Hello, and thanks for the answer!

Unfortunately, I am not sure what you mean by moving the batching into the dataset and squeezing out the singleton batch dimension.

Could you please elaborate a bit?
Thanks in advance.

This was for a client project involving CT scans, so the original code is not available. But here is a mickey-mousey example:

The classic way looks like this:

import torch
ds_classic = torch.utils.data.TensorDataset(torch.randn(32000, 100))
len(ds_classic), ds_classic[0][0].shape  # (32000, torch.Size([100]))

If we teach our dataset to return batches (which we arrange to be stratified / bucketed in any way we like), we have roughly this:

ds_batched = torch.utils.data.TensorDataset(torch.randn(100, 32, 100))  # 100 batches of 32 samples
len(ds_batched), ds_batched[0][0].shape  # (100, torch.Size([32, 100]))

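Since my first answer, I found that you can pass batch_size=None, which turns off the DataLoader's automatic batching so that each dataset item is returned as one batch: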

dl_classic = torch.utils.data.DataLoader(ds_classic, batch_size=32, shuffle=True)
dl_batched = torch.utils.data.DataLoader(ds_batched, batch_size=None, shuffle=True)

So after this, we get the same shape of batches for both.

batch_classic, = next(iter(dl_classic))
prebatched_batch, = next(iter(dl_batched))

The advantage of going through the dataloader is that you can have the num_workers magic and all.

Another option is to use a custom sampler, and there will be people who prefer that, but in my opinion, the above is hard to beat in terms of simplicity. The drawback is that unless you manually shuffle the batch elements in each epoch, you get the same batches (but in different order). This may or may not be OK.

Best regards



Thank you for the detailed answer. I think I understand what’s going on in the provided code, but I fail to see how this addresses my issue.

To give more details: just as your artificial input here is 32000 samples of 100 dimensions each, in my problem let’s say I also have 32000 samples, but the dimensions vary. A data point might have 35 features, another might have 80, etc.

The code you posted has batches of equal length by default, since all samples are 100-D. However, I want the same thing in my case, where the sizes differ. Therefore, what I want is a way to group samples of equal or similar length into the same batch before they are passed to the model.

If for example my batch size is 4, I want each 4 samples to be as close as possible in terms of length. I do not wish, for instance, to have a batch where the 4 files have the following input sizes: 80, 20, 35, 10. Instead, I want something like this: 80, 75, 82, 90 (this is just an example).

Unless of course, you mean that I should first sort my data from longest to shortest (or in any case, arrange them as I wish) and then do the ‘batch_size=None’ DataLoader?

If so, forgive me for not understanding earlier.

Thanks again for your responses!

Unless of course, you mean that I should first sort my data from longest to shortest (or in any case, arrange them as I wish) and then do the ‘batch_size=None’ DataLoader?

Yes, this is what I’d do. But then again, it is just what I did the last time I looked at this problem. If you think it is not a good fit for your application, you should not hesitate to ignore the suggestion.
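For variable-length waveforms, "sort first, then pre-batch in the dataset and use batch_size=None" could look roughly like this. It is a sketch assuming all waveforms fit in memory as 1-D tensors; `LengthBucketedDataset` is a hypothetical name, and a real ASR pipeline would more likely load the files inside `__getitem__`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class LengthBucketedDataset(Dataset):
    """Each item is a whole, already padded batch of similar-length waveforms."""

    def __init__(self, waveforms, batch_size):
        self.waveforms = waveforms
        # Sort indices from longest to shortest, then cut into fixed-size chunks.
        order = sorted(range(len(waveforms)), key=lambda i: waveforms[i].shape[0], reverse=True)
        self.batches = [order[s:s + batch_size] for s in range(0, len(order), batch_size)]

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        wavs = [self.waveforms[i] for i in self.batches[idx]]
        lengths = torch.tensor([w.shape[0] for w in wavs])
        return pad_sequence(wavs, batch_first=True), lengths


# Hypothetical data: 10 random "waveforms" of varying length.
waveforms = [torch.randn(n) for n in [80, 20, 35, 10, 75, 82, 90, 30, 50, 55]]
ds = LengthBucketedDataset(waveforms, batch_size=4)
# batch_size=None: each dataset item is already a batch, so no extra dimension is added.
dl = DataLoader(ds, batch_size=None, shuffle=True)
for padded, lengths in dl:
    print(padded.shape, lengths.tolist())
```

Here `shuffle=True` only changes the order of the three batches; their contents stay fixed from epoch to epoch.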

Best regards


Hi, it is certainly something I am already considering. Sadly, I still haven’t found another solution. In any case, thank you for your help, I appreciate it.

So what is holding you back?
My strategy would be to have a preprocessing step that collects information like each file's length in a CSV or so, and then you can sort on the length column. Personally, I like to lazily use pandas for doing things with the metadata. If all your data has the same sample rate (and format), you can just use the file size and scan your files when initializing the dataset.
With that information, you can group the data to form batches.
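A rough sketch of that preprocessing step, assuming uncompressed PCM .wav files. It uses only the standard-library `wave` and `csv` modules instead of pandas so it stays self-contained; `scan_lengths`, `batches_from_csv`, and the demo files are hypothetical.

```python
import csv
import os
import tempfile
import wave


def write_silence(path, num_frames, sample_rate=16000):
    """Create a tiny mono 16-bit WAV file (only needed for this demo)."""
    with wave.open(path, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)
        f.setframerate(sample_rate)
        f.writeframes(b"\x00\x00" * num_frames)


def scan_lengths(paths, csv_path):
    """Record (path, num_frames) for every file and save it as CSV metadata."""
    rows = []
    for p in paths:
        with wave.open(p, "rb") as f:
            rows.append({"path": p, "num_frames": f.getnframes()})
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["path", "num_frames"])
        writer.writeheader()
        writer.writerows(rows)
    return rows


def batches_from_csv(csv_path, batch_size):
    """Read the metadata back, sort on the length column, and chunk into batches."""
    with open(csv_path, newline="") as f:
        rows = sorted(csv.DictReader(f), key=lambda r: int(r["num_frames"]), reverse=True)
    paths = [r["path"] for r in rows]
    return [paths[s:s + batch_size] for s in range(0, len(paths), batch_size)]


tmp = tempfile.mkdtemp()
paths = []
for name, n in [("a", 8000), ("b", 1600), ("c", 12000), ("d", 4000)]:
    p = os.path.join(tmp, name + ".wav")
    write_silence(p, n)
    paths.append(p)
scan_lengths(paths, os.path.join(tmp, "meta.csv"))
for batch in batches_from_csv(os.path.join(tmp, "meta.csv"), batch_size=2):
    print([os.path.basename(p) for p in batch])  # ['c.wav', 'a.wav'] then ['d.wav', 'b.wav']
```

Reading `getnframes()` from the header avoids decoding the audio, so the scan stays fast even for large datasets.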

This is what I plan to try, based on what we discussed yesterday. I just don’t have enough experience yet to know beforehand whether it will work as I want it to.

One thing I’m trying to figure out is whether something you mentioned will cause problems. By doing

dl_batched = torch.utils.data.DataLoader(ds_batched, batch_size=None, shuffle=True)

I will get the same batches, albeit in a different order, every epoch. Will this be an issue? On the other hand, even Trainer's group_by_length option has to be doing something similar. I tried checking its implementation in the docs, but to be honest I am not sure what it is doing.
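As far as I can tell from the transformers source, `group_by_length` (a `TrainingArguments` flag) makes the Trainer use a `LengthGroupedSampler`, which shuffles all indices, splits them into "megabatches" of several batches, and sorts by length only inside each megabatch. The sketch below is a simplified reimplementation of that idea, not the transformers code; `megabatch_grouped_indices` and `mega_batch_mult=4` are assumptions. Because the initial shuffle changes every epoch, so do the batches.

```python
import torch


def megabatch_grouped_indices(lengths, batch_size, mega_batch_mult=4, generator=None):
    """Shuffle all indices, then sort by length only inside each megabatch.

    A megabatch holds mega_batch_mult * batch_size indices. A different
    random permutation gives different batches, but each batch still
    contains samples of similar length.
    """
    indices = torch.randperm(len(lengths), generator=generator).tolist()
    mega = mega_batch_mult * batch_size
    batches = []
    for start in range(0, len(indices), mega):
        chunk = sorted(indices[start:start + mega], key=lambda i: lengths[i], reverse=True)
        batches.extend(chunk[s:s + batch_size] for s in range(0, len(chunk), batch_size))
    return batches


lengths = torch.randint(1000, 50000, (64,)).tolist()
for epoch in range(2):
    g = torch.Generator().manual_seed(epoch)  # new seed per epoch -> new batches
    batches = megabatch_grouped_indices(lengths, batch_size=4, generator=g)
    print(epoch, batches[0], [lengths[i] for i in batches[0]])
```

Larger megabatches give tighter length grouping but less randomness; smaller ones do the opposite.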

I will eventually try the method you proposed to me here though, just to see if it works as expected!

EDIT: there is something in the dataloader docs called sampler, which “defines the strategy to draw samples from the dataset”. Perhaps creating something custom related to this might offer a more ‘expected’ solution to my problem.
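One way to use that hook is a batch sampler passed through the DataLoader's `batch_sampler` argument, which yields one list of indices per batch. A minimal sketch; `BucketBatchSampler` and its `bucket_batches` knob are hypothetical. Each epoch it sorts by length, splits the order into buckets of several batches, shuffles inside each bucket, cuts batches, and shuffles the batch order, so batches change between epochs while lengths stay close.

```python
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Sampler


class BucketBatchSampler(Sampler):
    """Yield lists of indices whose lengths are close to each other."""

    def __init__(self, lengths, batch_size, bucket_batches=8, seed=0):
        self.lengths = lengths
        self.batch_size = batch_size
        self.bucket_size = batch_size * bucket_batches
        self.rng = random.Random(seed)  # state advances, so every epoch differs

    def __iter__(self):
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = []
        for s in range(0, len(order), self.bucket_size):
            bucket = order[s:s + self.bucket_size]
            self.rng.shuffle(bucket)  # mix similar-length items within the bucket
            batches.extend(bucket[b:b + self.batch_size] for b in range(0, len(bucket), self.batch_size))
        self.rng.shuffle(batches)  # random batch order
        return iter(batches)

    def __len__(self):
        # Only the last bucket can be partial, so this equals ceil(n / batch_size).
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


# A plain list works as a map-style dataset here.
waveforms = [torch.randn(n) for n in torch.randint(1000, 5000, (40,)).tolist()]
sampler = BucketBatchSampler([w.shape[0] for w in waveforms], batch_size=4)
dl = DataLoader(waveforms, batch_sampler=sampler,
                collate_fn=lambda batch: pad_sequence(batch, batch_first=True))
for padded in dl:
    print(padded.shape)
```

The DataLoader calls `__iter__` again at the start of every epoch, so no manual reshuffling call is needed with this design.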

So you could add shuffling of the “similar length” items to the dataset and call it manually at the start of each epoch.
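A sketch of that idea: the pre-batched dataset gets a `reshuffle()` method that rebuilds its batches. To shuffle items of similar length, it sorts on the length plus a small random offset (a "noisy sort"), so near-equal lengths end up in different batches each call. `ReshufflingBatchedDataset` and `noise=500` are hypothetical choices.

```python
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ReshufflingBatchedDataset(Dataset):
    """Pre-batched dataset whose batches can be rebuilt between epochs."""

    def __init__(self, waveforms, batch_size, noise=500, seed=0):
        self.waveforms = waveforms
        self.batch_size = batch_size
        self.noise = noise  # max random offset (in audio samples) added to each length
        self.rng = random.Random(seed)
        self.reshuffle()

    def reshuffle(self):
        """Sort on a noisy length, then cut the order into batches."""
        keys = {i: w.shape[0] + self.rng.uniform(-self.noise, self.noise)
                for i, w in enumerate(self.waveforms)}
        order = sorted(keys, key=keys.get, reverse=True)
        self.batches = [order[s:s + self.batch_size] for s in range(0, len(order), self.batch_size)]

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        wavs = [self.waveforms[i] for i in self.batches[idx]]
        return pad_sequence(wavs, batch_first=True)


waveforms = [torch.randn(n) for n in torch.randint(1000, 5000, (40,)).tolist()]
ds = ReshufflingBatchedDataset(waveforms, batch_size=4)
dl = DataLoader(ds, batch_size=None, shuffle=True)
for epoch in range(3):
    ds.reshuffle()  # call manually at the start of every epoch
    for padded in dl:
        pass  # training step goes here
    print(epoch, ds.batches[0])
```

This assumes worker processes are recreated each epoch (the default `persistent_workers=False`); with persistent workers, they would keep their own copy of the old batches.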

A Sampler is another way to do this, but the reason I personally like datasets is that they seem more straightforward: the dataset is an interface that people are used to inspecting for correctness (though of course, you can do the same by running the DataLoader).


I agree with your points. For now I haven’t researched this further because some other issues came up, but I’ll get back to it tomorrow or the day after!

If I manage to properly address this challenge or if further questions appear, I’ll post it here to let you know. Thanks for the advice.