Handling sparse batches to speed up training

Hi everyone,

I’m working on a use case where the input to my model is a tensor of size (B, K, context_length). Here:

  • B is the batch size,
  • K represents different sources of information,
  • context_length is the per-source sequence (context) length

For any given sample, only a subset of these K sources contain meaningful information, while the rest are padded. When the sparsity in K is high (i.e., most sources are padded), this leads to inefficiencies in training due to unnecessary computations.
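
For concreteness, here is a toy illustration of the kind of input described above; the `source_mask` bookkeeping tensor is my own assumption and not part of the original setup:

```python
import torch

B, K, context_length = 4, 16, 128

# Input batch in which padded sources are all zeros.
x = torch.zeros(B, K, context_length)
# Hypothetical per-sample mask marking which of the K sources carry real data.
source_mask = torch.zeros(B, K, dtype=torch.bool)

# Example: sample 0 only has meaningful data in 3 of its 16 sources.
source_mask[0, :3] = True
x[0, :3] = torch.randn(3, context_length)

# Fraction of the K dimension that is pure padding for sample 0 (13/16 here).
sparsity = 1.0 - source_mask[0].float().mean().item()
```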

Proposed Idea

To address this, I’m thinking of implementing a custom DataLoader. The idea is:

  1. For each sample, load only the non-padded (dense) parts of the K dimension.
  2. While filling a batch, ensure that all sources for a single sample remain together (i.e., no scattering across batches). Continue sampling until the batch is complete.

The goal is to avoid wasting computation on padded entries.
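
As a rough sketch of step 1, and assuming the raw data is stored as (K, context_length) tensors in which padded sources are all zeros (my assumption, not stated above), the Dataset could return only the dense slice of each sample and precompute which sources are non-padded; the class and attribute names here are placeholders:

```python
import torch
from torch.utils.data import Dataset

class SparseSourceDataset(Dataset):
    """Returns only the non-padded sources of each sample."""

    def __init__(self, samples):
        # `samples`: list of (K, context_length) tensors; padded sources are all zeros.
        self.samples = samples
        # Precompute the indices of non-padded sources once, so lookups stay cheap.
        self.dense_idx = [
            (s.abs().sum(dim=-1) > 0).nonzero(as_tuple=True)[0] for s in samples
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        dense = self.samples[i][self.dense_idx[i]]  # (k_i, context_length)
        return dense, dense.shape[0]
```

This keeps all sources of a sample together and defers the variable-length handling to the batching logic.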

Questions

  • Has anyone implemented something similar, or is there a built-in PyTorch feature that could help with this?
  • Are there any suggestions or best practices for efficiently implementing such a custom DataLoader?

Reply

  • To implement a fully custom dataloader, you would have to do it in C++ and then maybe write Python bindings.
  • Consider using some kind of clever indexing, or find someone who can implement it for you. Knowledge of data structures is essential; if you can keep your lookups O(1), that would be ideal.

Thank you for the contribution. The problem can be addressed with a custom Sampler that shuffles the dataset and composes the batches before each epoch starts. The batching logic ensures that:

  1. Samples remain coherent: All sources for a single sample are kept together within the same batch, avoiding scattering across batches.
  2. Batch size constraints are respected: The sum of non-zero entries in the K dimension across all samples in a batch does not exceed the specified max_batch_size.

The all-zero (padded) entries can then be removed before the batch is passed to the model.
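
A minimal sketch of what such batching logic could look like, assuming a dataset like the one sketched earlier that exposes a per-sample count of non-padded sources; `PackedBatchSampler`, `collate_dense`, and the interpretation of `max_batch_size` (a cap on the total number of dense sources per batch) are my assumptions, not from the original post:

```python
import random
import torch
from torch.utils.data import DataLoader, Sampler

class PackedBatchSampler(Sampler):
    """Shuffles the dataset and greedily packs whole samples into batches so
    that the total number of non-padded sources per batch never exceeds
    max_batch_size. All sources of a sample stay in the same batch."""

    def __init__(self, num_dense_per_sample, max_batch_size, shuffle=True):
        self.counts = list(num_dense_per_sample)  # k_i for each sample
        self.max_batch_size = max_batch_size
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.counts)))
        if self.shuffle:
            random.shuffle(order)
        batch, budget = [], 0
        for idx in order:
            k_i = self.counts[idx]
            if batch and budget + k_i > self.max_batch_size:
                yield batch
                batch, budget = [], 0
            batch.append(idx)
            budget += k_i
        if batch:
            yield batch

    def __len__(self):
        # Rough estimate; the exact number of batches depends on the packing.
        return -(-sum(self.counts) // self.max_batch_size)

def collate_dense(items):
    # Each item is (dense_sources, k_i); concatenating along the source axis
    # means the padded entries never reach the model.
    sources = torch.cat([s for s, _ in items], dim=0)   # (sum(k_i), context_length)
    lengths = torch.tensor([k for _, k in items])        # to split per sample later
    return sources, lengths

# Hypothetical usage with the SparseSourceDataset sketched above:
# dataset = SparseSourceDataset(samples)
# counts = [len(ix) for ix in dataset.dense_idx]
# loader = DataLoader(dataset,
#                     batch_sampler=PackedBatchSampler(counts, max_batch_size=256),
#                     collate_fn=collate_dense)
```

Because the packing happens in __iter__, each epoch's shuffle produces a different batch composition; a variant that pre-sorts samples by k_i could pack batches more tightly if that matters for your workload.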