Balancing GPU load with DataParallel

When I wrap my model in DataParallel(model), I see an imbalance between the memory allocated on each GPU. Is there a way to balance it better? This imbalance seems to make OOM errors more likely.

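This is expected, as it's a known drawback of nn.DataParallel. In every forward pass, DataParallel scatters the inputs and replicates the model from the default device (device_ids[0], usually cuda:0). It then gathers all replicas' outputs back onto that device and reduces the gradients there, so that GPU ends up holding noticeably more memory than the others. This is one of the reasons why we recommend using DistributedDataParallel instead. With DDP, each process drives a single GPU, computes its own loss, and synchronizes gradients with an all-reduce, so memory use stays roughly even across devices.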
