DataParallel divided into blocks?

For model-specific reasons, I have to apply DataParallel to multiple separate blocks instead of to one single model. These blocks are then chained together.

Block1 -> Block2 -> etc.

Each is a PyTorch module.

Now the problem is that DataParallel gathers (collates) the parallelized batches at every intermediate block!

Block1 (Parallelized) -> Collate -> Block2 (Parallelized) -> Collate -> etc.

This collating is unnecessary and slow, but it is an inadvertent byproduct of this modularization.
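To make the setup concrete, here is a minimal sketch of what I mean. `Block` is a hypothetical stand-in for my actual Block1, Block2, etc.; the point is that each block is wrapped in its own DataParallel, so the output is gathered after every block and scattered again before the next:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Placeholder for Block1, Block2, ... (the real blocks differ)."""
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Each block is wrapped separately, so every forward call scatters
# its input across the devices and gathers the output again.
block1 = nn.DataParallel(Block(16))
block2 = nn.DataParallel(Block(16))

x = torch.randn(8, 16)
out = block2(block1(x))  # gather after block1, re-scatter before block2
```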

Is there any way to collate just-in-time, or some adaptive collate method, that keeps the batches parallelized through all of the intermediate blocks and only collates once at the end, at BlockN?

I’m unsure what the exact use case is, but if you want to avoid the gather/scatter operations between the DataParallel blocks, you could just wrap all blocks into a single DataParallel module.
Could you explain why this wouldn’t work as it seems you have a specific limitation for this use case?
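Something like this, as a minimal sketch (assuming the blocks can be composed with `nn.Sequential`; `Block` is a hypothetical stand-in for your real modules). The chain is wrapped once, so the input is scattered a single time and the output gathered a single time at the end:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Placeholder for Block1, Block2, ... (substitute your real blocks)."""
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Chain the blocks first, then wrap the whole chain once:
# scatter happens once at the input, gather once at the final output.
chain = nn.Sequential(Block(16), Block(16), Block(16))
model = nn.DataParallel(chain)  # on a machine without GPUs this falls back to a plain forward

out = model(torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 16])
```

If the blocks need non-sequential wiring, a thin custom `nn.Module` whose `forward` calls them in order would work the same way, as long as DataParallel wraps only that outer module.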