Hi everyone,
I’m working on a use case where the input to my model is a tensor of size (B, K, context_length), where:
- B is the batch size,
- K represents different sources of information,
- context_length is the context length.
For any given sample, only a subset of these K sources contains meaningful information, while the rest are padded. When the sparsity along K is high (i.e., most sources are padded), this leads to inefficiencies in training due to unnecessary computations.
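For concreteness, here is a minimal sketch of the layout I mean (assuming zero-padding; the shapes and numbers are made up for illustration):

```python
import torch

B, K, context_length = 4, 16, 128

# Fully padded input tensor; only a few of the K source slots per sample carry data.
x = torch.zeros(B, K, context_length)
x[0, :3] = torch.randn(3, context_length)   # sample 0 uses 3 sources
x[1, :1] = torch.randn(1, context_length)   # sample 1 uses 1 source
x[2, :5] = torch.randn(5, context_length)   # sample 2 uses 5 sources
x[3, :2] = torch.randn(2, context_length)   # sample 3 uses 2 sources

# Which (sample, source) slots are real vs. padding.
valid = x.abs().sum(dim=-1) != 0             # shape (B, K)
print(valid.sum().item(), "dense sources out of", B * K)  # 11 out of 64
```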
Proposed Idea
To address this, I’m thinking of implementing a custom DataLoader. The idea is:
- For each sample, load only the non-padded (dense) parts of the K dimension.
- While filling a batch, ensure that all sources for a single sample remain together (i.e., no scattering across batches).
- Continue sampling until the batch is complete.
The goal is to avoid wasted computation on padded entries.
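A rough sketch of what I have in mind is below. The class and function names (DenseSourceDataset, pack_collate) are placeholders, and it assumes each sample can be stored as just its dense (k_i, context_length) slice:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DenseSourceDataset(Dataset):
    """Each item returns only the non-padded sources of one sample,
    as a (k_i, context_length) tensor, so no padding is ever loaded."""
    def __init__(self, samples):
        self.samples = samples  # list of (k_i, context_length) tensors

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def pack_collate(batch):
    """Concatenate all dense sources in the batch along dim 0 and record
    per-sample offsets, so each sample's sources stay together."""
    lengths = torch.tensor([item.shape[0] for item in batch])
    packed = torch.cat(batch, dim=0)              # (sum of k_i, context_length)
    offsets = torch.cumsum(
        torch.cat([torch.zeros(1, dtype=torch.long), lengths]), dim=0
    )                                             # (B + 1,)
    return packed, offsets

# Hypothetical usage: the model consumes the packed tensor and uses
# `offsets` to know which rows belong to which sample.
samples = [torch.randn(k, 128) for k in (3, 1, 5, 2)]
loader = DataLoader(DenseSourceDataset(samples), batch_size=4,
                    collate_fn=pack_collate)
packed, offsets = next(iter(loader))
print(packed.shape, offsets)  # torch.Size([11, 128]) tensor([ 0,  3,  4,  9, 11])
```

The model would then slice `packed[offsets[i]:offsets[i+1]]` to recover sample i's sources, so no padded rows enter the forward pass.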
Questions
- Has anyone implemented something similar, or is there a built-in PyTorch feature that could help with this?
- Are there any suggestions or best practices for efficiently implementing such a custom DataLoader?