Debugging DataParallel, no speedup and uneven memory allocation

With one GPU and a batch size of 14 an epoch on my data set takes about 24 minutes. With 2 GPUs and a batch size of 28 it’s still taking 24 minutes per epoch. Any suggestions on what might be going wrong? Does the batch normalization layer try to normalize across both GPUs and thus add large amounts of extra memory traffic? Please say it doesn’t.

Thanks.

Top shows 2 CPUs saturated:

Tasks: 255 total,   1 running, 254 sleeping,   0 stopped,   0 zombie
%Cpu(s): 16.3 us,  2.5 sy,  0.1 ni, 81.1 id,  0.0 wa,  0.0 hi,  0.0 si,  0.0 st
KiB Mem : 65885928 total, 40001592 free, 11878640 used, 14005696 buff/cache
KiB Swap: 67017724 total, 67017724 free,        0 used. 52840116 avail Mem 

  PID USER      PR  NI    VIRT    RES    SHR S  %CPU %MEM     TIME+ COMMAND                                                                           
 4622 mmacy     20   0 49.225g 5.809g 977604 S 200.0  9.2 111:32.30 work/vnet.base.

The memory allocation on the two GPUs is also uneven. If they’re both doing the same operations with the same batch size, why is GPU1 using 1/3rd more memory than GPU0?

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 367.57                 Driver Version: 367.57                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  TITAN X (Pascal)    Off  | 0000:01:00.0      On |                  N/A |
| 51%   82C    P2    74W / 250W |   7906MiB / 12186MiB |     99%      Default |
+-------------------------------+----------------------+----------------------+
|   1  TITAN X (Pascal)    Off  | 0000:02:00.0     Off |                  N/A |
| 47%   78C    P2   107W / 250W |  10326MiB / 12189MiB |     95%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID  Type  Process name                               Usage      |
|=============================================================================|
|    0      1086    G   /usr/lib/xorg/Xorg                             105MiB |
|    0      8469    C   work/vnet.base.20170316_0434                  7797MiB |
|    1      8469    C   work/vnet.base.20170316_0434                 10323MiB |
+-----------------------------------------------------------------------------+


I also see that parallel_apply in data_parallel relies on Python threading, which isn’t worth much given how much of the code has to run under the GIL. The only way to get any sort of reasonable parallelism from regular GIL-protected Python code is to run separate Python processes.

Are other people actually seeing a speedup from DataParallel? I think I’m probably only seeing one thread make progress at a time.

1 Like

No, it doesn’t. I don’t understand what the problem is. Are you expecting a larger speedup? You’ve doubled the amount of computing power along with the input size, so the time staying constant is a great result.

Maybe you have some conditional branching in your model?


Our multi-GPU code is nearly as fast as it can be. You still need the GIL to execute most of the autograd operations anyway, so it doesn’t matter that we use Python threads. In our CNN benchmarks we see performance very similar to that of Caffe2, and faster than Lua Torch, which has a very good multi-GPU implementation.

Usually “data parallel” means the data operations run in parallel, but here it only means that the forward passes, the fast part, have any parallel component. That’s just not very useful, since larger batches can reduce the convergence rate. So I’m using twice the power and generating twice the heat while getting no real benefit.

The use of Python makes parallelism unduly hard, but one could run the autograd pass in separate processes, hogwild-style, without sharing the model weights: each process would broadcast the updates from the cost function’s backward pass and average in the updates from the other GPUs.

Yes, and the operations do run in parallel, both in forward and in backward. I’m not sure how that’s not useful or what the problem is. Data parallelism is a well-known term used to describe weak scaling of these models, and that is exactly what happens here. It’s very useful, because in some cases you need batch sizes so large that they don’t fit on a single GPU. This behaviour is exactly what’s wanted in a lot of situations.

It doesn’t make it hard; I really don’t understand what’s bothering you. You can do hogwild training, even on GPUs if you want. No one is stopping you. We even have a hogwild example in the examples repo.
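For reference, the core of that hogwild pattern looks roughly like this (a simplified sketch of what the examples repo does; train_epoch is assumed to be your own per-process training function):

import torch.multiprocessing as mp

def run_hogwild(model, train_epoch, num_processes=4):
    model.share_memory()  # put the parameters in shared memory; updates are lock-free
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train_epoch, args=(rank, model))
        p.start()
        processes.append(p)
    for p in processes:   # wait for all workers to finish
        p.join()

# Call run_hogwild from under `if __name__ == "__main__":` so process spawning
# works on all platforms.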

2 Likes

As Adam pointed out, DataParallel is very well defined in deep learning. We split the mini-batches over multiple GPUs and accumulate the gradients from all the GPUs at the end, before doing the optimization step.

Getting a linear speedup when doubling the batch size is the best case scenario, and you are hitting that.

If there was confusion in terminology, I hope it’s clarified now.

2 Likes

I don’t think @mattmacy is getting a linear speed-up: it’s his epoch time that stays constant, not the time per minibatch. That’s more like no speed-up at all (the same time per epoch means twice the time for a twice-as-big minibatch).
There are many scenarios where the speed-up from data parallelism is not that great - e.g. the OpenNMT example hardly benefits from it.

1 Like

@ngimel

batch-size 14: 1 GPU: 24 mins
batch-size 28: 2 GPUs: 24 mins

in terms of weak scaling, this seems appropriate, right? We cannot expect any better speedup than this theoretically.

You expect your epoch to become twice as fast with two GPUs. Time per minibatch should stay constant with perfect weak scaling, but you should have half as many minibatches per epoch (if the dataset consists of 140 samples, that’s 10 minibatches of size 14, and 5 minibatches of size 28).

Oh, my bad. I got confused because all my life I’ve been used to seeing per-minibatch times. Sorry for the misunderstanding @mattmacy, I apologize.

But as @ngimel and @apaszke pointed out, there are many scenarios in which DataParallel is not great, especially if you have too little compute or too many parameters in your model.

1 Like

@ngimel thank you! Yes, when I go from 340 minibatches to 170 minibatches I expect the wall-clock time to drop from 22 minutes to 11 minutes. Instead I’m seeing no change.

@smth I think I’m missing something: DataParallel only splits up the forward pass, so I would think that in order to see a proper speed-up during training I’d need to encapsulate the whole training loop, so that the gradients from the loss’s backward pass were averaged and then propagated. This snippet is the logic I’m referring to: http://docs.chainer.org/en/latest/tutorial/gpu.html#data-parallel-computation-on-multiple-gpus-without-trainer Is there some equivalent to their ParallelUpdater that I’m overlooking?

Otherwise, I have 61989982 parameters in my model. I guess that’s too many? Why would DataParallel be rate-limited by that? And is it possible I’m doing something wrong such that I’m ending up with more parameters than I mean to have? It’s just a 5-level encoder-decoder FCN with 1-3 convolutions at each level and a skip connection from the output of each encoder level to the decoder level with the same resolution and number of channels. See the diagram from the paper https://github.com/mattmacy/vnet.pytorch/blob/master/images/diagram.png if my description doesn’t make sense.

Thank you for your time.

DataParallel also distributes the backward pass; it is hidden in autograd. DataParallel has to broadcast and reduce all the parameters, so parallelization efficiency decreases when your computation time is small and you have a lot of parameters.
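For a rough sense of scale (back-of-the-envelope only): the ~62 million float32 parameters mentioned above are about 62e6 x 4 bytes ≈ 248 MB, and roughly that much data has to be broadcast to the second GPU and then reduced back as gradients on every iteration, on top of the actual computation.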

1 Like

In the backward pass of DataParallel, we reduce the gradients from GPU2 onto GPU1.

Our DataParallel algorithm is roughly like this:

in forward:

  • scatter mini-batch to GPU1, GPU2
  • replicate model on GPU2 (it is already on GPU1)
  • model_gpu1(input_gpu1), model_gpu2(input_gpu2) (this step is parallel_apply)
  • gather output mini-batch from GPU1, GPU2 onto GPU1

in backward:

  • scatter grad_output and input
  • parallel_apply model’s backward pass
  • reduce GPU2 replica’s gradients onto GPU1 model
  • Now there is only a single model again with accumulated gradients from GPU1 and GPU2

  • gather the grad_input

Hence, unlike in Chainer, you do not actually have to have a separate trainer that is aware of DataParallel.
Hope this makes it clear.
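To make the steps above concrete, here is a rough sketch (not the exact library internals) built from the functional primitives exported by torch.nn.parallel; it assumes the model and the batch already live on GPU0:

import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(model, batch, device_ids=(0, 1), output_device=0):
    # split the mini-batch along dim 0 and send one chunk to each GPU
    inputs = scatter((batch,), device_ids)
    # copy the model's parameters and buffers onto every GPU
    replicas = replicate(model, device_ids[:len(inputs)])
    # run each replica on its chunk; one Python thread per GPU
    outputs = parallel_apply(replicas, inputs)
    # concatenate the per-GPU outputs back on the output device
    return gather(outputs, output_device)

# No extra code is needed for the backward pass: scatter/replicate/gather are
# autograd operations, so backpropagating through the gathered output
# automatically reduces the replica gradients onto the GPU0 copy of the model.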

Wrt why your model is slower via DataParallel: you have 61 million parameters, so I presume you have some Linear layers at the end (i.e. fully connected layers). Put them outside the purview of DataParallel to avoid having to distribute/reduce those parameter weights and gradients. Here are two examples of doing that:

https://github.com/pytorch/examples/blob/master/imagenet/main.py#L68
https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py

When training AlexNet or VGG, we only put model.features in DataParallel, and not the whole model itself, because AlexNet and VGG have large Linear layers at the end of the network.

Maybe your situation is similar?
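A minimal sketch of that pattern, following the linked imagenet example (assumes torchvision is installed and uses vgg16 purely for illustration):

import torch
import torch.nn as nn
import torchvision.models as models

model = models.vgg16()
# wrap only the convolutional trunk in DataParallel; these are the only
# parameters that get broadcast and reduced across GPUs each iteration
model.features = nn.DataParallel(model.features)
model.cuda()

x = torch.randn(28, 3, 224, 224).cuda()  # the batch is split across GPUs inside model.features
out = model(x)                           # the large Linear classifier runs on a single GPU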

11 Likes

Is it necessary to replicate all of the gradients? Couldn’t you replicate just the output of the loss function’s backward pass and then average the results?

There are no fully connected layers. I guess 3D convolutions with >= 128 channels have an exorbitant number of parameters. I may have made an error in going up to 512 channels when I only meant to go up to 256, so at least you’ve prompted me to take a closer look at the model.

Thanks again.

This is the result of printing the number of parameters for each of the basic elements:
Conv3d: 2016
BatchNorm3d: 32
Conv3d: 4128
BatchNorm3d: 64
Conv3d: 128032
BatchNorm3d: 64
Conv3d: 16448
BatchNorm3d: 128
Conv3d: 512064
BatchNorm3d: 128
Conv3d: 512064
BatchNorm3d: 128
Conv3d: 65664
BatchNorm3d: 256
Conv3d: 2048128
BatchNorm3d: 256
Conv3d: 2048128
BatchNorm3d: 256
Conv3d: 2048128
BatchNorm3d: 256
Conv3d: 262400
BatchNorm3d: 512
Conv3d: 8192256
BatchNorm3d: 512
Conv3d: 8192256
BatchNorm3d: 512
Conv3d: 8192256
BatchNorm3d: 512
ConvTranspose3d: 262272
BatchNorm3d: 256
Conv3d: 8192256
BatchNorm3d: 512
Conv3d: 8192256
BatchNorm3d: 512
Conv3d: 8192256
BatchNorm3d: 512
ConvTranspose3d: 131136
BatchNorm3d: 128
Conv3d: 2048128
BatchNorm3d: 256
Conv3d: 2048128
BatchNorm3d: 256
ConvTranspose3d: 32800
BatchNorm3d: 64
Conv3d: 512064
BatchNorm3d: 128
ConvTranspose3d: 8208
BatchNorm3d: 32
Conv3d: 128032
BatchNorm3d: 64
Conv3d: 8002
BatchNorm3d: 4
Conv3d: 6
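The per-module counts above can be reproduced with something like this (a generic sketch, assuming a standard nn.Module model):

import torch.nn as nn

def print_param_counts(model):
    # walk all submodules and report parameter counts for the basic layer types
    for m in model.modules():
        if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d)):
            n = sum(p.numel() for p in m.parameters())
            print(f"{m.__class__.__name__}: {n}")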

@smth Revisiting it, the channel split numbers are in fact correct. I tried replacing each of the 5x5x5 filters with two 3x3x3 filters, which reduced the parameters to 27 million, but it actually increased the memory consumption on the GPU and provided no speedup. So I guess I’ll just have to stick with using the additional GPU for hyperparameter search.
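(Back-of-the-envelope: a 5x5x5 Conv3d with 128 input and 128 output channels has 5·5·5·128·128 + 128 = 2,048,128 parameters, matching the listing above, while two stacked 3x3x3 convolutions at the same width have about 2 × (3·3·3·128·128 + 128) ≈ 0.88M. So the parameter count drops, but each extra convolution also has to keep an extra intermediate activation around for the backward pass, which would explain the higher memory use.)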

Thanks.

@smth I tried removing some of the largest convolutional layers - with no ill effect - and now DataParallel epochs are taking 15 minutes. Thanks for the explanations!

1 Like

sweet, that’s great news.

Soumith and Adam, I am having a great time exploring PyTorch! Thanks for the awesome library.

I am trying to saturate a 64-core/256-thread CPU in addition to the GPUs. Any pointers on how I can extend Data_parallel.py to create 3 scatters on GPU0, GPU1, and CPU(0-255)?

With Keras I modified this script to saturate the CPU:

@FuriouslyCurious So you want to run the model in parallel on two GPUs and all cores? We don’t have any utility for that, and I don’t think it’s even worth it :confused: The code will be more complex and you’ll probably see hardly any speedup.

1 Like

Given the statements in this thread, is there no speedup expected for classic DNNs @apaszke?
I currently run speaker-recognition DNNs with PyTorch, and increasing the number of GPUs used (e.g. to 2) while at the same time increasing the batch size (256 -> 512) does not affect the training time at all. A full epoch on Switchboard takes me ~300 minutes with a single GPU, and just as long with 2 GPUs and double the batch size. The model is a 6-layer DNN with 2048 nodes per layer. The 2 GPUs are utilized, i.e., they are not idle at all.

So for DNNs, is there no speedup expected if the model has a large number of parameters?

@ngimel @smth I’m running into a situation where using DataParallel ends up training slower than running without it on a single GPU (keeping everything else constant). If I try to increase the batch size, I run out of memory. I’m running an encoder-decoder (with attention) model with 3 million parameters. When running on one GPU, I’m able to use a batch size of 2048 sequences, which takes up about 6000MiB of the 6078MiB. Whereas on two GPUs (using DataParallel on all layers in the encoder and decoder), running the same batch size takes up 6070MiB on GPU 1 and only 1022MiB on GPU 2.

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.66                 Driver Version: 375.66                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 980 Ti  Off  | 0000:06:00.0     Off |                  N/A |
|  0%   51C    P2   167W / 300W |   6070MiB /  6078MiB |     73%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 980 Ti  Off  | 0000:07:00.0      On |                  N/A |
|  0%   45C    P2    93W / 300W |   1021MiB /  6075MiB |     35%      Default |
+-------------------------------+----------------------+----------------------+

I have tried putting the linear layers outside of DataParallel, to no avail - the machine ran out of memory.

I understand that the computations on GPU 1 require more memory than those on GPU 2, but I was expecting more memory to be used on GPU 2. Am I at maximum capacity/performance? Is there anything I can do with this seq-to-seq model to train more batches at a time or shorten the training time?