Parallel training on multiple GPUs without saturating the first GPU

tl;dr:
I am trying to increase the batch size for training ResNet50 with the PyTorch ImageNet training example, but since the first GPU holds both the model and data, I can’t increase the batch size without saturating that GPU. How can I put the data only on the other GPUs?

Details:
I’m doing this over 4 GPUs, currently with a batch size of 64.

The screenshot above shows how the ResNet model and ImageNet data are distributed across the four GPUs. It also shows that GPU 0 uses more memory than GPUs 1, 2 and 3, because it holds the ResNet model in addition to the data.
How do I put the data only on GPUs 1, 2 and 3 while keeping the model on GPU 0, so that I can increase the batch size and fully use all the GPUs?
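For context on the per-GPU share: `nn.DataParallel` replicates the module onto every GPU in `device_ids`, splits each input batch along dimension 0 across them, and gathers the outputs back onto `output_device`, which defaults to `device_ids[0]` (GPU 0 here). The CPU-only sketch below uses `torch.Tensor.chunk` as a stand-in for that internal scatter step (an approximation, not DataParallel's actual code path) to show how each GPU's slice grows with the total batch size. It does not model the extra memory GPU 0 needs for the gathered outputs.

```python
import torch


def split_batch(batch: torch.Tensor, num_gpus: int) -> list:
    """Approximate how nn.DataParallel scatters a batch along dim 0 (CPU-only stand-in)."""
    return list(batch.chunk(num_gpus, dim=0))


def per_gpu_batch_sizes(total_batch: int, num_gpus: int) -> list:
    # Tiny 1x1 "images" keep the sketch cheap; only the batch dimension matters here.
    batch = torch.zeros(total_batch, 3, 1, 1)
    return [chunk.size(0) for chunk in split_batch(batch, num_gpus)]


print(per_gpu_batch_sizes(64, 4))   # [16, 16, 16, 16]
print(per_gpu_batch_sizes(256, 4))  # [64, 64, 64, 64]
print(per_gpu_batch_sizes(66, 4))   # [17, 17, 17, 15]
```

Because every replica gets roughly the same slice, raising the total batch size raises the load on GPU 0 as much as on the others, and GPU 0 still carries its extra overhead on top.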

I assume you are using DataParallel, and that’s why the inputs have to stay on cuda:0? Would DistributedDataParallel (DDP) (Distributed Data Parallel — PyTorch 1.10 documentation) help? With DDP, each process can have its own dataloader that loads data onto its own device.