Hi there,
I’m going to re-edit the whole thread to describe an unexpected behavior of DataParallel.
There are several recent posts about this topic and I would like to summarize the problem.
It seems there is an imbalanced usage of GPUs when calling DataParallel.
Based on my own experience and other users’ explanations, here is why this happens:
When using DataParallel, there are 4 types of data to consider:
- Input
- Output
- Ground-truth
- Optimizer Parameters
DataParallel has a main GPU, which is the GPU where the model is stored.
As @Yuzhou_Song indicates in another post, DataParallel splits the batch across as many GPUs as chosen, copies the model to each of them, computes the forward pass independently, and then gathers the outputs from every GPU back onto a single GPU to compute the loss, instead of computing the loss independently on each GPU. This is the main cause of the imbalanced memory usage. Keep in mind that the ground-truth and the output (the two inputs of the loss) must be on the same GPU.
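Here is a minimal sketch of that default behavior (the toy model, tensor shapes and the assumption of two or more visible GPUs are mine, not from the original posts):

```python
import torch
import torch.nn as nn

# Default DataParallel behavior: scatter the batch, replicate the model,
# run forward on every GPU, then gather the outputs on cuda:0.
model = nn.DataParallel(nn.Linear(128, 10)).cuda()   # main GPU is cuda:0
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(64, 128).cuda()                 # scattered across GPUs inside forward()
targets = torch.randint(0, 10, (64,)).cuda()         # must sit on cuda:0, where outputs are gathered

outputs = model(inputs)                              # comes back on cuda:0
loss = criterion(outputs, targets)                   # so the loss is also computed on cuda:0
loss.backward()
```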
I also found out that some optimizers require a lot of memory to store their state. All of that state is located on the aforementioned main GPU, which makes the problem worse.
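You can check this yourself with a quick sketch (Adam here is just an example of a stateful optimizer; the rest is a toy setup of mine):

```python
import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(128, 10)).cuda()
optimizer = torch.optim.Adam(model.parameters())

# Adam creates its per-parameter buffers (exp_avg, exp_avg_sq) lazily on the first step,
# on the same device as the parameters, i.e. the main GPU.
model(torch.randn(64, 128).cuda()).sum().backward()
optimizer.step()

param = next(model.parameters())
print(optimizer.state[param]['exp_avg'].device)      # prints: cuda:0
```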
There is one last reason: model inputs are usually moved to the GPU with .cuda(), which by default points to the main GPU and generates even more imbalance.
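For example (illustrative only, `inputs` stands for any input batch):

```python
inputs = inputs.cuda()     # no index: goes to the current device, i.e. cuda:0, the main GPU
inputs = inputs.cuda(2)    # explicit index: goes to cuda:2 instead
# equivalently: inputs = inputs.to('cuda:2')
```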
There are some ways to minimize this:
DataParallel has 2 arguments: device_ids, which lets you choose on which of the available GPUs (CUDA_VISIBLE_DEVICES) the model will be trained, and output_device, which lets you choose on which GPU the outputs will be gathered.
Let’s imagine we have 3 devices, [0,1,2].
Calling
model = DataParallel(model).cuda()
would set device 0 as the main GPU; the optimizer’s parameters will then be stored there.
Calling
model = DataParallel(model, output_device=1).cuda()
and
groundtruth.cuda(1)
will gather all the outputs and compute the loss on cuda:1.
Lastly, you can allocate the inputs to cuda:2.
This way the memory usage is distributed as evenly as possible.
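Putting the three pieces together, a minimal sketch could look like this (the toy model, loss and tensor shapes are placeholders of mine, assuming 3 visible GPUs):

```python
import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(128, 10),
                        device_ids=[0, 1, 2],        # replicas on cuda:0/1/2
                        output_device=1).cuda()      # gather outputs on cuda:1
optimizer = torch.optim.Adam(model.parameters())     # optimizer state stays on cuda:0 (main GPU)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(64, 128).cuda(2)                # inputs allocated on cuda:2
targets = torch.randint(0, 10, (64,)).cuda(1)        # ground-truth on cuda:1, next to the outputs

outputs = model(inputs)                              # gathered on cuda:1
loss = criterion(outputs, targets)                   # loss computed on cuda:1
loss.backward()                                      # gradients are reduced back to cuda:0
optimizer.step()
```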
Is there a way of solving this problem? @smth @ptrblck @albanD @soulless
I guess this behavior is very inconvenient.