Hi everyone,
I’m working on a use case where the input to my model is a tensor of size (B, K, context_length). Here:
- B is the batch size,
- K represents different sources of information,
- context_length is the context length.
For any given sample, only a subset of these K sources contain meaningful information, while the rest are padded. When the sparsity in K is high (i.e., most sources are padded), this leads to inefficiencies in training due to unnecessary computations.
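To make the setup concrete, here is a small sketch of what such a batch looks like and how the per-sample sparsity can be measured. It assumes padded sources are all-zero along the context dimension; the shapes and the zero-padding convention are illustrative assumptions, not something fixed by the problem statement.

```python
import torch

# Toy batch, assuming padding means an all-zero row along context_length.
B, K, context_length = 4, 8, 16
x = torch.zeros(B, K, context_length)
# Give each sample a different number of "real" (dense) sources.
x[0, :2] = torch.randn(2, context_length)
x[1, :1] = torch.randn(1, context_length)
x[2, :5] = torch.randn(5, context_length)
x[3, :3] = torch.randn(3, context_length)

# A source is dense if any element along context_length is non-zero.
dense_mask = x.abs().sum(dim=-1) != 0      # (B, K) boolean mask
dense_per_sample = dense_mask.sum(dim=-1)  # (B,) counts: [2, 1, 5, 3]
sparsity = 1.0 - dense_mask.float().mean().item()  # fraction of padded sources
```

When `dense_per_sample` is much smaller than K on average, most of the compute in a standard (B, K, context_length) forward pass is spent on padding.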
Proposed Idea
To address this, I’m thinking of implementing a custom DataLoader. The idea is:
- For each sample, load only the non-padded (dense) parts of the K dimension.
- While filling a batch, keep all sources for a single sample together (i.e., no scattering across batches).
- Continue sampling until the batch is complete.
The goal is to avoid padded computations.
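A minimal sketch of what I have in mind, using a custom `collate_fn`: each sample yields only its dense sources, and the collate function concatenates them into one flat tensor plus a per-row sample index, so sources from one sample stay contiguous and no padded rows reach the model. The dataset class and the zero-padding assumption are hypothetical, just to make the idea concrete.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SparseSourceDataset(Dataset):
    """Hypothetical dataset returning only the dense sources of each sample.

    Assumes each stored sample is a (K, context_length) tensor whose padded
    sources are all-zero rows.
    """
    def __init__(self, samples):
        self.samples = samples  # list of (K, context_length) tensors

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx]
        # Keep only rows that contain real data: (k_i, context_length), k_i <= K.
        return x[x.abs().sum(dim=-1) != 0]

def pack_collate(batch):
    """Pack variable-K samples into one flat tensor plus per-row sample ids.

    All sources of a sample stay adjacent, so nothing is scattered across
    batches, and no padded rows are carried into the forward pass.
    """
    packed = torch.cat(batch, dim=0)  # (sum_i k_i, context_length)
    sample_ids = torch.repeat_interleave(
        torch.arange(len(batch)),
        torch.tensor([b.shape[0] for b in batch]),
    )  # maps each packed row back to its sample
    return packed, sample_ids

# Toy data: samples with 2, 1, 5, and 3 dense sources out of K = 8.
samples = [
    torch.cat([torch.randn(k, 16), torch.zeros(8 - k, 16)]) for k in (2, 1, 5, 3)
]
loader = DataLoader(SparseSourceDataset(samples), batch_size=4, collate_fn=pack_collate)
packed, sample_ids = next(iter(loader))
# packed has shape (11, 16); sample_ids is [0, 0, 1, 2, 2, 2, 2, 2, 3, 3, 3]
```

The open question is whether this flat-packing layout (rows plus sample ids) is the right target, or whether something built-in handles the variable-K case more directly.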
Questions
- Has anyone implemented something similar, or is there a built-in PyTorch feature that could help with this?
- Are there any suggestions or best practices for efficiently implementing such a custom DataLoader?