Combine model sharding with data parallelism


I am currently trying to get the BERT architecture to run on GPUs. Because the sequences in my target dataset are very long (2200 characters), even the small model can't fit a single batch into 16 GB of memory.
For this reason I split the transformer blocks across 2 different GPUs, which works quite well.

Would it be possible to combine this approach with data parallelism à la nn.DataParallel()? Ideally I would like to have 4 groups of 2 GPUs computing in parallel.

Have a look at this example, which combines model sharding with nn.DataParallel.
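For readers who want something self-contained, here is a minimal sketch of one way this combination could look. It is not necessarily the same as the example referred to above. The `Stage` and `ShardedDataParallel` classes, the layer sizes, and the rule for splitting the GPUs into two halves are all assumptions made for illustration; `Stage` just stands in for a group of transformer blocks. The idea is to wrap each shard in its own nn.DataParallel with a disjoint set of device IDs and move the activations between the two groups by hand.

```python
import torch
import torch.nn as nn


class Stage(nn.Module):
    """Hypothetical stand-in for a group of transformer blocks."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_in, d_out), nn.ReLU())

    def forward(self, x):
        return self.net(x)


class ShardedDataParallel(nn.Module):
    """Two shards, each replicated over its own group of GPUs."""

    def __init__(self, d_model=32):
        super().__init__()
        self.stage1 = Stage(d_model, d_model)
        self.stage2 = Stage(d_model, d_model)

        n_gpus = torch.cuda.device_count()
        if n_gpus >= 4:
            # Assumption: first half of the GPUs hosts shard 1,
            # second half hosts shard 2 (e.g. 0-3 and 4-7 on 8 GPUs).
            half = n_gpus // 2
            ids1 = list(range(0, half))
            ids2 = list(range(half, 2 * half))
            self.dev1 = torch.device(f"cuda:{ids1[0]}")
            self.dev2 = torch.device(f"cuda:{ids2[0]}")
            # DataParallel expects the parameters on device_ids[0].
            self.stage1 = nn.DataParallel(self.stage1.to(self.dev1), device_ids=ids1)
            self.stage2 = nn.DataParallel(self.stage2.to(self.dev2), device_ids=ids2)
        else:
            # Fallback so the sketch also runs without enough GPUs.
            self.dev1 = self.dev2 = torch.device("cpu")

    def forward(self, x):
        # Each DataParallel wrapper scatters the batch over its own group
        # and gathers the result back on the first device of that group.
        x = self.stage1(x.to(self.dev1))
        x = self.stage2(x.to(self.dev2))
        return x


model = ShardedDataParallel(d_model=32)
out = model(torch.randn(8, 32))
print(out.shape, out.device)
```

On 8 GPUs this would give 4 replicas of each shard, which is close to the "4 groups of 2 GPUs" layout asked about. Note that the output of the first shard is gathered on one GPU before being scattered again for the second shard, so this is a simple starting point rather than the most efficient layout.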