For model-specific reasons, I have to apply DataParallel to multiple separate blocks instead of to one single model. These blocks are then chained together:
Block1 -> Block2 -> etc.
Each is a PyTorch module.
Now the problem is that DataParallel collates (gathers) the parallelized batches back onto the default device at each intermediate block:
Block1 (Parallelized) -> Collate -> Block2 (Parallelized) -> Collate -> etc.
This collating is unnecessary and slow, but it is an inadvertent byproduct of this modularization.
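Here is a minimal sketch of the setup (the block definitions, layer sizes, and names are just placeholders, not my real model):

```python
import torch
import torch.nn as nn

# Placeholder blocks standing in for the real model pieces.
class Block1(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(512, 512)

    def forward(self, x):
        return torch.relu(self.layer(x))

class Block2(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(512, 10)

    def forward(self, x):
        return self.layer(x)

# Each block is wrapped in DataParallel separately.
block1 = nn.DataParallel(Block1()).cuda()
block2 = nn.DataParallel(Block2()).cuda()

x = torch.randn(256, 512, device="cuda")

# block1 scatters x across the GPUs, runs its replicas, then gathers the
# outputs back onto the default device; block2 immediately scatters them
# again. This gather/scatter round trip happens at every intermediate block.
out = block2(block1(x))
```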
Is there any way to collate just-in-time, or some adaptive collation method, that keeps the batches parallelized through all of the intermediate blocks and only collates at the end, after BlockN?