PyTorch - One GPU uses more memory during training - How to mitigate?

My PyTorch training code looks like this:

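import torch
import torch.nn.functional as F
gathered_output = torch.nn.parallel.data_parallel(model, model_input, range(ngpu))
loss = F.cross_entropy(gathered_output, ground_truth)  # stands in for the post's placeholder F.loss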

I am working on a large image segmentation task, and GPU 0 uses almost 20% more memory than the other GPUs. Is there a way to mitigate this? The gathered output isn't that large, just an [8,128,300,400] float tensor on CUDA. I have 4 GPUs, and each GPU processes 2 samples of the batch. But I guess the gradients are all gathered back too?
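A rough back-of-the-envelope count helps here, assuming float32 (4 bytes per element). The gathered output is 8×128×300×400 values, all of it on GPU 0, and during backward the gradient of the loss with respect to that gathered output has the same size and is also created on GPU 0 before being scattered back to the replicas. The helper name `tensor_mib` below is just for illustration:

```python
def tensor_mib(shape, bytes_per_element=4):
    """Memory of a dense tensor in MiB (float32 by default)."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_element / 2**20


# Full output gathered on GPU 0 (batch of 8)
gathered = tensor_mib((8, 128, 300, 400))
# Slice each GPU produces for its own 2 samples
per_gpu = tensor_mib((2, 128, 300, 400))
print(f"gathered output: {gathered:.2f} MiB, per-GPU slice: {per_gpu:.2f} MiB")
```

So the gathered output plus its gradient come to roughly 0.9 GiB on GPU 0 alone, before counting any intermediate tensors the loss function itself allocates. In addition, `data_parallel` reduces the parameter gradients from all replicas onto the module on GPU 0.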

Here I used the functional version, `data_parallel`. Is there a difference compared with using the wrapped version, `nn.DataParallel`?
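As far as I can tell from the PyTorch source, `nn.DataParallel.forward` performs the same scatter, replicate, parallel_apply and gather steps that `torch.nn.parallel.data_parallel` performs in a single call. Both gather the output on `device_ids[0]` by default, so switching to the wrapper should not change the imbalance by itself. A minimal side-by-side sketch, with a toy `nn.Conv2d` standing in for the real segmentation model:

```python
import torch
import torch.nn as nn

# Toy stand-in for the real segmentation model (hypothetical)
model = nn.Conv2d(3, 8, kernel_size=1)
ngpu = torch.cuda.device_count()
if ngpu > 0:
    # Both forms expect the module's parameters on device_ids[0]
    model = model.to("cuda:0")

model_input = torch.randn(8, 3, 16, 16)

# Wrapper form: scatter/replicate/parallel_apply/gather happen inside forward().
# With no GPUs available, nn.DataParallel simply calls the wrapped module.
dp_model = nn.DataParallel(model, device_ids=list(range(ngpu)) or None)
wrapped_output = dp_model(model_input)

if ngpu > 0:
    # Functional form: the same steps in one call
    functional_output = nn.parallel.data_parallel(
        model, model_input, device_ids=list(range(ngpu))
    )
    # Both gather the full output on the same device (device_ids[0] by default)
    assert functional_output.device == wrapped_output.device
```

The wrapper is mainly a convenience: it remembers `device_ids` and `output_device` and exposes the original model as `dp_model.module`.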

If I compute the loss on each GPU and only gather back the losses, will I save more memory? Please help.
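Computing the loss inside each replica is a common way to reduce the imbalance: only one loss value per GPU is gathered instead of the full [8,128,300,400] output, and the loss's intermediate tensors stay on the GPU that produced them. Parameter gradients are still reduced onto GPU 0, so some imbalance remains. A minimal sketch, assuming a hypothetical `ModelWithLoss` wrapper and cross-entropy as the loss (substitute the real one):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: each replica runs the model *and* the loss."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, model_input, ground_truth):
        output = self.model(model_input)
        # Cross-entropy is an assumption; use the loss you actually train with
        loss = F.cross_entropy(output, ground_truth)
        # Return a 1-element tensor so gathering yields a vector of per-GPU losses
        return loss.unsqueeze(0)


def parallel_loss(model_with_loss, model_input, ground_truth):
    ngpu = torch.cuda.device_count()
    if ngpu > 0:
        # data_parallel requires the module to live on device_ids[0]
        model_with_loss.to("cuda:0")
        # A tuple of inputs is scattered element-wise and unpacked into forward()
        per_gpu_losses = nn.parallel.data_parallel(
            model_with_loss, (model_input, ground_truth), device_ids=list(range(ngpu))
        )
    else:
        # CPU fallback so the sketch still runs without GPUs
        per_gpu_losses = model_with_loss(model_input, ground_truth)
    # Only this small vector was gathered; average it into one scalar
    return per_gpu_losses.mean()
```

Averaging the per-GPU means matches the full-batch mean only when every GPU receives the same number of samples, which holds for a batch of 8 on 4 GPUs.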

You could try to adapt the tips from @Thomas_Wolf in his blog post, which also explains the workflow of nn.DataParallel. Alternatively, you could use nn.DistributedDataParallel.
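With `nn.parallel.DistributedDataParallel` each GPU is driven by its own process, which computes the loss on its own shard of the batch; gradients are averaged across processes during `backward()`, so no output tensor is gathered on a single GPU. A minimal sketch, assuming the script is launched with `torchrun` (which sets `LOCAL_RANK`, `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`), cross-entropy as the loss, and hypothetical helper names `setup_ddp` and `train_step`:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_ddp(model: nn.Module) -> nn.Module:
    """Join the process group and wrap `model` for this process (one per GPU)."""
    if torch.cuda.is_available():
        # torchrun sets LOCAL_RANK for each process it launches
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        dist.init_process_group("nccl")
        return DDP(model.to(local_rank), device_ids=[local_rank])
    # CPU fallback with the gloo backend so the sketch also runs without GPUs
    dist.init_process_group("gloo")
    return DDP(model)


def train_step(ddp_model, optimizer, model_input, ground_truth):
    device = next(ddp_model.parameters()).device
    output = ddp_model(model_input.to(device))
    # Loss on this process's shard only (cross-entropy is an assumed loss)
    loss = F.cross_entropy(output, ground_truth.to(device))
    optimizer.zero_grad()
    loss.backward()  # DDP all-reduces and averages gradients across processes here
    optimizer.step()
    return loss.detach()
```

It would be launched with something like `torchrun --nproc_per_node=4 train.py` (the script name is hypothetical), and each process should load a different shard of the data, for example via `torch.utils.data.distributed.DistributedSampler`, so a global batch of 8 becomes 2 samples per process.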