TL;DR: The Torch7 model took 0.5s/batch; the same model in PyTorch runs at 1.3-6s/batch (2-3x slower). How should we diagnose the bottleneck?
We have a DenseNet-like model forked from fb.resnet.torch that takes 720x720 images as input. For our original model in Torch7, a forward pass on an 8-TitanXp node took around 0.5s with a batch size of 112. After migrating the same model structure to PyTorch, the backward pass became slightly faster (0.7s to 0.5s) and the max batch size increased (112 to 120), but the forward pass became 2-3x slower (1.4-7s/batch), even with the same batch size. (By forward pass, I mean the line output = model(x).)
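One thing worth ruling out when comparing forward and backward timings: CUDA kernels launch asynchronously, so a wall-clock timer around output = model(x) may stop before the GPU has finished, and the leftover work then shows up in whichever later call blocks. Below is a minimal, stdlib-only timing helper sketched under that assumption; the name time_forward and its warmup/iters defaults are hypothetical, and the sync argument is where you would pass torch.cuda.synchronize.

```python
import statistics
import time
from typing import Callable


def time_forward(fn: Callable[[], object],
                 sync: Callable[[], None] = lambda: None,
                 warmup: int = 3,
                 iters: int = 10) -> float:
    """Return the median wall-clock seconds of fn(), syncing around each sample."""
    # Warm-up runs absorb one-off costs (e.g. cudnn.benchmark autotuning).
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        sync()                      # drain work queued by earlier calls
        start = time.perf_counter()
        fn()
        sync()                      # wait for asynchronously launched kernels
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Hypothetical usage on a GPU machine (model and x are your own objects):
#   import torch
#   secs = time_forward(lambda: model(x), sync=torch.cuda.synchronize)
```

Using the median rather than the mean keeps a single slow outlier batch from dominating the reported number.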
Some other interesting facts:
- We initially tried TensorFlow for the migration, but its max batch size was too low for our needs.
- The default PyTorch implementation of DenseNet in torchvision, despite having fewer layers than our model, is about 10% slower and has a 50% lower max batch size.
- GPU utilization drops more frequently with PyTorch than with Torch7; while the Torch7 model pushes the GPUs to 80C+, the PyTorch model stays at 60-70C.
- A single-GPU PyTorch run with 16 images per batch takes 0.2s/batch, which is just the original speed divided by the GPU count (i.e., seemingly no gain from using 8 GPUs in parallel).
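The "no parallel gain" observation can be made concrete with throughput arithmetic. The sketch below (the function name is hypothetical) takes the single-GPU numbers from the last bullet and, as an assumption, the 1.4s lower end of the reported PyTorch forward time for 112 images on 8 GPUs. The shards are not identical (16 images vs. 112/8 = 14 per GPU), so treat the result as approximate.

```python
def data_parallel_scaling(single_imgs: int, single_secs: float,
                          multi_imgs: int, multi_secs: float,
                          n_gpus: int) -> tuple[float, float]:
    """Return (speedup, efficiency) of a multi-GPU run over a single-GPU run.

    speedup    = multi-GPU throughput / single-GPU throughput (images per second)
    efficiency = speedup / n_gpus, where 1.0 would be perfect linear scaling
    """
    single_throughput = single_imgs / single_secs
    multi_throughput = multi_imgs / multi_secs
    speedup = multi_throughput / single_throughput
    return speedup, speedup / n_gpus


# Numbers from the post: 16 images in 0.2s on one GPU vs. 112 images in 1.4s
# (best case of the reported forward-pass range) on 8 GPUs.
print(data_parallel_scaling(16, 0.2, 112, 1.4, 8))  # ≈ (1.0, 0.125)
```

Both configurations work out to about 80 images/s, i.e. a speedup of roughly 1x and a scaling efficiency of about 12.5% across 8 GPUs, even in the best case.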
Some methods we tried for PyTorch:
- Using Docker to control the experiment environment, except for the cuDNN version (5 for Torch7, 6 for PyTorch)
- Attempts at optimizing the PyTorch code according to the docs: using nn.DataParallel, nn.Sequential, inplace=True, pinned memory and async, making sure NCCL is working, adjusting num_workers, mimicking the structure of the torchvision examples, turning cudnn.benchmark on, and not wrapping the last fully connected layers in DataParallel.
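Since num_workers was one of the knobs adjusted, it may help to time the input pipeline on its own, with no model in the loop, to see whether loading keeps up at all. A minimal sketch, assuming you can build a loader from a single integer setting; sweep_loader_setting is a hypothetical helper, not a PyTorch API.

```python
import time
from typing import Callable, Dict, Iterable, Sequence


def sweep_loader_setting(make_loader: Callable[[int], Iterable],
                         candidates: Sequence[int],
                         n_batches: int = 50) -> Dict[int, float]:
    """Mean seconds per batch when iterating a loader built for each candidate.

    Only the input pipeline is timed; no model is run, so a slow result here
    points at loading/augmentation rather than at the forward pass.
    """
    results = {}
    for value in candidates:
        loader = make_loader(value)
        count = 0
        start = time.perf_counter()
        for _batch in loader:
            count += 1
            if count >= n_batches:
                break
        elapsed = time.perf_counter() - start
        # Divide by the batches actually seen, in case the loader is short.
        results[value] = elapsed / max(count, 1)
    return results


# Hypothetical usage (dataset is your own torch.utils.data.Dataset):
#   from torch.utils.data import DataLoader
#   sweep_loader_setting(
#       lambda w: DataLoader(dataset, batch_size=112, num_workers=w, pin_memory=True),
#       candidates=[4, 8, 16, 32])
```

The first batch of each run includes worker start-up time, so a larger n_batches gives a fairer steady-state number. If the loader alone is already near 0.5s/batch, the forward pass is not the only thing to look at.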
I understand that our input size is unusually large and that PyTorch is still in beta; also, since we cannot disclose our model structure, the question has to be general: are there pitfalls to avoid and hacks to try when optimizing model speed?
We are aware that some extreme multi-GPU cases can have these issues.
We're fixing these performance problems for the next release.
In the meantime, you can try using NCCL2 instead of the default NCCL.
Download NCCL2, then run:
NCCL_ROOT_DIR=[where nccl2 is] python setup.py install
Also, if you can provide us with some benchmark scripts, we can update you when this performance issue is fixed.
Contribute them to: https://github.com/pytorch/benchmark
Thanks Soumith! I tried rebuilding with NCCL2 as suggested, but unfortunately without much good news. NCCL2 did make the per-batch compute time more stable than the fluctuating times with NCCL1, but the overall speed gain was very small or nonexistent.
I've tracked the bottleneck down to the effect of the GIL on [...] replicate phases were not so harshly affected by the number of GPUs up to 8 (even with async off), the per-thread forward time in parallel_apply increased almost linearly with the number of threads (GPUs).
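The GIL effect described here can be reproduced without PyTorch. The stdlib-only sketch below (function names are hypothetical) runs the same pure-Python, CPU-bound loop in 1, 2, 4, and 8 threads at once and reports the mean per-thread wall time. On a standard (GIL-enabled) CPython build, that loop holds the GIL throughout, so per-thread time should grow roughly with the thread count, much like the parallel_apply observation. It illustrates only Python-level contention; in a real model, presumably the threads compete during the Python-side work of each forward call rather than while GPU kernels execute.

```python
import threading
import time


def busy_work(n: int) -> int:
    """Pure-Python CPU-bound loop; it holds the GIL the whole time."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def per_thread_times(num_threads: int, n: int = 200_000) -> list[float]:
    """Run busy_work in num_threads concurrent threads; return each thread's wall time."""
    times = [0.0] * num_threads

    def worker(idx: int) -> None:
        start = time.perf_counter()
        busy_work(n)
        times[idx] = time.perf_counter() - start

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return times


if __name__ == "__main__":
    for k in (1, 2, 4, 8):
        ts = per_thread_times(k)
        print(f"{k} threads: mean per-thread time {sum(ts) / len(ts):.3f}s")
```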
So I tried modifying DataParallel to spawn background processes instead of threads, and saw that it brought the per-GPU forward time down to that of the single-thread mode. However, the cost of IPC (e.g., pickling and unpickling Modules and input Variables on each batch) limited the final result to only slightly (around 10%) faster than the threaded version. More importantly, I saw that Variable.grad_fn is blanked out when passing through ForkingPickler; having lost the computation graph, automatic backward was no longer possible.
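To get a rough feel for per-batch serialization overhead, a stdlib pickle round trip can be timed directly. This is only a stand-in under stated assumptions: the function name and the plain-bytes payload are hypothetical, and torch.multiprocessing's ForkingPickler moves tensor storage through shared memory and sends handles rather than copying the bytes. So the sketch approximates what plain pickling of a comparable payload costs, not what the modified DataParallel actually paid.

```python
import pickle
import time
from typing import Any, Tuple


def pickle_roundtrip_cost(obj: Any, iters: int = 5) -> Tuple[int, float]:
    """Return (pickled size in bytes, mean seconds per dumps+loads round trip)."""
    size = len(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
    start = time.perf_counter()
    for _ in range(iters):
        pickle.loads(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
    return size, (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # Plain-bytes stand-in for one GPU's shard of a 112-image batch over 8 GPUs:
    # 14 uint8 images of 720x720x3 (real float32 tensors would be 4x larger).
    shard = [bytes(720 * 720 * 3) for _ in range(14)]
    size, secs = pickle_roundtrip_cost(shard)
    print(f"{size / 1e6:.1f} MB per shard, {secs * 1e3:.1f} ms per round trip")
```

Because this cost is paid on every batch for every worker, even a few tens of milliseconds per round trip can eat much of the gain from escaping the GIL.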
As you suggested, I will try to generalize our case into code that we can share in the benchmark repo. But if this is indeed a limitation of the GIL, the new torch.jit might just be the solution. I'll track the nightlies and report back if there's any news.
Out of curiosity, are you doing any augmentation on the 720px images? From what I understand, data augmentation is not multi-threaded, even though the loader is.
@FuriouslyCurious How are you applying your augmentation? If you implement it in the Dataset object (e.g., torchvision.datasets.ImageFolder), it gets spread across separate worker processes in DataLoader, running in parallel with the main process.
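A framework-free sketch of the pattern described above: augmentation lives in __getitem__, so whichever process fetches an item also augments it. AugmentedDataset and load_batch are hypothetical names; ProcessPoolExecutor stands in for DataLoader's worker processes, and num_workers=0 mirrors DataLoader's in-main-process mode. In PyTorch you would subclass torch.utils.data.Dataset and let DataLoader(num_workers=...) do the fan-out.

```python
import random
from concurrent.futures import ProcessPoolExecutor
from typing import List, Sequence


class AugmentedDataset:
    """Hypothetical stand-in for a torch Dataset: augmentation lives in __getitem__."""

    def __init__(self, samples: Sequence[List[int]], seed: int = 0):
        self.samples = [list(s) for s in samples]
        self.seed = seed

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> List[int]:
        # Per-index RNG so an item's augmentation doesn't depend on which worker runs it.
        rng = random.Random(self.seed + idx)
        sample = list(self.samples[idx])
        if rng.random() < 0.5:      # stand-in for a random horizontal flip
            sample.reverse()
        return sample


def load_batch(dataset: AugmentedDataset, indices: Sequence[int],
               num_workers: int = 0) -> List[List[int]]:
    """Fetch (and therefore augment) items in the main process or in worker processes."""
    if num_workers == 0:
        return [dataset[i] for i in indices]
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(dataset.__getitem__, indices))
```

Called from a script's if __name__ == "__main__": block, load_batch(ds, range(len(ds)), num_workers=2) returns the same items as the num_workers=0 path, because each item's randomness is seeded by its index. PyTorch's DataLoader instead seeds each worker differently, so this is a simplification.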