My model’s input examples are sequences of highly variable length. If I pack a fixed number of these sequences into a batch tensor (which is allocated to accommodate the longest sequence), I eventually hit a case where too many long sequences end up in the same batch, and my model runs out of memory.
Finding a “safe” batch size is annoyingly time-consuming, because the above situation may only arise many epochs into training.
I would prefer to find a way to assemble my batches such that the sum of lengths of all sequences in them is roughly constant. How can I achieve this with PyTorch DataLoader?
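One way to approximate this is a batch sampler that sorts example indices by length and greedily closes a batch once its padded size would exceed a fixed token budget. This is a minimal sketch that assumes a list with each example's length is already available. Because the batch tensor is sized by its longest sequence, it caps `max_len * batch_size` rather than the raw sum of lengths; with length-sorted batches the two are close. `TokenBudgetBatchSampler`, `lengths` and `max_tokens` are hypothetical names, not part of PyTorch.

```python
import random
from typing import Iterator, List, Sequence


class TokenBudgetBatchSampler:
    """Hypothetical batch sampler: yields index lists whose padded size
    (longest length in the batch * number of examples) stays <= max_tokens.

    Assumes `lengths[i]` is the pre-computed length of dataset example i.
    """

    def __init__(self, lengths: Sequence[int], max_tokens: int,
                 shuffle: bool = True, seed: int = 0) -> None:
        if any(n > max_tokens for n in lengths):
            raise ValueError("an example is longer than max_tokens")
        self.lengths = list(lengths)
        self.max_tokens = max_tokens
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        # Batches are built once, so len() is stable across epochs.
        self.batches = self._make_batches()

    def _make_batches(self) -> List[List[int]]:
        # Sorting by length groups similar lengths together, which keeps
        # padding waste (and hence wasted memory) low.
        order = sorted(range(len(self.lengths)), key=self.lengths.__getitem__)
        batches: List[List[int]] = []
        current: List[int] = []
        current_max = 0
        for idx in order:
            new_max = max(current_max, self.lengths[idx])
            # Close the batch if adding idx would exceed the padded budget.
            if current and new_max * (len(current) + 1) > self.max_tokens:
                batches.append(current)
                current, new_max = [], self.lengths[idx]
            current.append(idx)
            current_max = new_max
        if current:
            batches.append(current)
        return batches

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[List[int]]:
        batches = [list(b) for b in self.batches]
        if self.shuffle:
            # Shuffle the order of batches (not their contents),
            # reproducibly for a given seed and epoch.
            random.Random(self.seed + self.epoch).shuffle(batches)
        return iter(batches)

    def __len__(self) -> int:
        return len(self.batches)
```

An instance would be passed as `DataLoader(dataset, batch_sampler=sampler, collate_fn=...)`. `batch_sampler` accepts any iterable of index lists and is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`. A padding `collate_fn` (for example one built on `torch.nn.utils.rnn.pad_sequence`) is still needed to turn each index list into a tensor.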
The only way I could think of is to create a data-aware batch sampler. However, that would require pre-scanning all the data to find the length of each example… The solution also needs to work with DistributedDataParallel. Can I have a distributed batch sampler?
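For DistributedDataParallel, one option, loosely modeled on how `torch.utils.data.distributed.DistributedSampler` shards indices, is for every rank to build the same batch list, shuffle it with the same seed, and keep every `num_replicas`-th batch starting at its own rank. The list is padded by repeating batches so all ranks run the same number of steps; otherwise a rank that runs out early could leave the others waiting in a collective. This is a sketch under those assumptions: `DistributedTokenBudgetBatchSampler` and `make_budget_batches` are hypothetical, and in a real script `num_replicas` and `rank` would typically come from `torch.distributed.get_world_size()` and `torch.distributed.get_rank()`.

```python
import math
import random
from typing import Iterator, List, Sequence


def make_budget_batches(lengths: Sequence[int], max_tokens: int) -> List[List[int]]:
    """Greedily pack length-sorted indices so that max_len * size <= max_tokens."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches: List[List[int]] = []
    current: List[int] = []
    for idx in order:
        # Indices arrive in ascending length, so lengths[idx] is the new max.
        if current and lengths[idx] * (len(current) + 1) > max_tokens:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches


class DistributedTokenBudgetBatchSampler:
    """Hypothetical per-rank batch sampler: every rank builds and shuffles the
    same batches, then keeps every num_replicas-th one from its own rank."""

    def __init__(self, lengths: Sequence[int], max_tokens: int,
                 num_replicas: int, rank: int,
                 shuffle: bool = True, seed: int = 0) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        if not lengths or any(n > max_tokens for n in lengths):
            raise ValueError("lengths must be non-empty and each <= max_tokens")
        self.batches = make_budget_batches(lengths, max_tokens)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        # Every rank must run the same number of steps, so round up.
        self.num_batches = math.ceil(len(self.batches) / num_replicas)
        self.total = self.num_batches * num_replicas

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[List[int]]:
        batches = [list(b) for b in self.batches]
        if self.shuffle:
            # Same seed + epoch on every rank gives an identical order everywhere.
            random.Random(self.seed + self.epoch).shuffle(batches)
        # Pad by cycling from the start so the total divides evenly across ranks.
        padded = [batches[i % len(batches)] for i in range(self.total)]
        return iter(padded[self.rank::self.num_replicas])

    def __len__(self) -> int:
        return self.num_batches
```

As with `DistributedSampler`, `set_epoch(epoch)` would be called at the start of each epoch so the batch order changes between epochs while staying identical across ranks. Ranks still see different token counts per step; they are only bounded by the same budget.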
So, is there a canonical way to achieve variable batch sizing, which works best with PyTorch’s infrastructure?