Will pytorch native mixed precision training be slower than apex when there are multiple losses?

Hi,

A simplified version of my training pipeline is like this:

import torch.cuda.amp as amp

scaler = amp.GradScaler()

optim.zero_grad()
with amp.autocast(enabled=True):
    logits1, logits2, logits3 = model(imgs)
    loss1 = criteria(logits1, labels)
    loss_aux = [criteria(logits2, labels), criteria(logits3, labels)]   
    loss = loss1 + sum(loss_aux)
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
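
The learning-rate scheduler is not shown in the snippet above; in the full pipeline I step it after scaler.step(optim), roughly like this (lr_schdr is an illustrative name, not the exact code):

scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
lr_schdr.step()  # scheduler is stepped only after scaler.step(optim) and scaler.update()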

While the apex version is like this:

from apex import amp

# model and optim are assumed to have been prepared with amp.initialize beforehand
optim.zero_grad()
logits1, logits2, logits3 = model(imgs)
loss1 = criteria(logits1, labels)
loss_aux = [criteria(logits2, labels), criteria(logits3, labels)]   
loss = loss1 + sum(loss_aux)

with amp.scale_loss(loss, optim) as scaled_loss:
    scaled_loss.backward()
optim.step()

I am using pytorch 1.6.0 with python3.6.9 on ubuntu 16.04 (docker container) with cuda 10.1.243/cudnn7.

There are two things that puzzle me:
The first one is that I received a warning like this:

/miniconda/envs/py36/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:123: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

though I am sure I call the lr_scheduler after scaler.step(optim), and when I disable autocast with autocast(enabled=False), the warning does not appear.

The second is that the pytorch native version is much slower than the apex version. With pytorch's autocast, 100 iterations take about 110s, while with apex they take around 80s.
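
(As a side note on how such timings can be taken fairly, this is only a sketch, not the exact measurement code from my run; run_one_iter stands for the amp loop shown above and dataloader for my data pipeline:)

import time
import torch

torch.cuda.synchronize()          # make sure queued GPU work is finished before timing
start = time.time()
for _, (imgs, labels) in zip(range(100), dataloader):
    run_one_iter(imgs, labels)    # one forward/backward/step as in the snippet above
torch.cuda.synchronize()          # wait for the last iterations to finish on the GPU
print(f'100 iters took {time.time() - start:.1f}s')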

Is there any problem with my usage of this new feature, and how could I solve it?

The learning rate scheduler warning might be raised if the gradients contained invalid values and the optimizer.step() was thus skipped.
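
One way to check for this (a sketch, assuming the scaler/optimizer/scheduler names from the snippets above; comparing the scale before and after scaler.update() is a common workaround, since the scale is reduced whenever a step is skipped due to inf/NaN gradients):

scaler.scale(loss).backward()
scaler.step(optim)                  # silently skipped if the gradients contain inf/NaN
scale_before = scaler.get_scale()
scaler.update()                     # reduces the scale if the step was skipped
if scaler.get_scale() >= scale_before:
    lr_schdr.step()                 # only advance the schedule when the optimizer actually stepped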

Which opt_level in apex are you using and did you use the same CUDA and cudnn versions for the comparison?

Thanks for replying!! I am using the apex O1 opt level, and yes, I tried the new pytorch 1.6 feature on the same platform as the original apex codebase. The pytorch native fp16 training is slower than the apex-based fp16 training.
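
(For reference, the opt_level is the argument passed to apex's initialize call; the names here mirror the simplified snippet above:)

from apex import amp

# O1 casts whitelisted ops to fp16 while keeping the model weights in fp32
model, optim = amp.initialize(model, optim, opt_level="O1")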

Could you update to the CUDA10.2 binaries (which would ship with cudnn7.6.5.32) and post the general model architecture? You don’t need to post the exact model if that isn’t possible; a “similar” architecture would be sufficient.

Hi,

I am happy that you would like to review my code. I created a docker image, and you can see my problem by pulling and running it:

$ docker pull coincheung/debug_zzy:first
$ nvidia-docker run -it --ipc=host coincheung/debug_zzy:first bash

To reproduce the apex training, the command is:

# cd /debug/bisenetv2
# CUDA_VISIBLE_DEVICES=6,7 python -m torch.distributed.launch --nproc_per_node=2 train.py

And in order to switch to the pytorch native fp16, the command is:

# cd /debug/bisenetv2
# CUDA_VISIBLE_DEVICES=6,7 python -m torch.distributed.launch --nproc_per_node=2 train_amp.py

On my platform, the training time of apex is around 85s/100iter, and the pytorch native fp16 training time is around 115s/100iter.

I think the problem is likely associated with interpolation. When I replace the interpolation with nn.PixelShuffle, the problem no longer exists.
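
For illustration (this is not the exact model code, just the kind of swap I mean), an F.interpolate upsampling can be replaced with a PixelShuffle head like this:

import torch.nn as nn
import torch.nn.functional as F

class UpsampleInterp(nn.Module):
    # original style: bilinear interpolation upsampling by a factor of 2
    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

class UpsamplePixelShuffle(nn.Module):
    # replacement: 1x1 conv to 4x the channels, then PixelShuffle rearranges
    # them into a feature map that is 2x larger spatially
    def __init__(self, channels):
        super().__init__()
        self.proj = nn.Conv2d(channels, channels * 4, kernel_size=1)
        self.shuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        return self.shuffle(self.proj(x))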

Is 10.2 recommended for native AMP? If so, why?

We are working on fixing regressions for interpolations at the moment. One issue is tracked here.

Not just for AMP, but I would always recommend using the latest CUDA version to get all bug fixes, new features, etc. Since CUDA11 is only available in the nightly binaries at the moment, I fell back to CUDA10.2 for the latest stable release.

Agreed, although it isn’t always easy to find the time to get all of my colleagues to use the latest PT+CUDA versions. Some of them are still on 0.4.1 + CUDA9.

I really have to find time to figure out how I can use docker images for my development projects. I feel it would make things a lot easier, rather than having to manually install a new venv with PyTorch and manually set the correct paths for CUDA. Docker should do away with all that overhead, but I haven’t figured out how to use it during development. If only we had an infinite amount of time!

(Sorry for topic hijack!)

If you don’t need to rebuild from source and would be alright using the stable release or the nightly binaries, you wouldn’t have to worry about CUDA at all, since the conda binaries and pip wheels ship with the CUDA, cudnn, NCCL etc. runtime libs.
You would only have to make sure your workstation has the right NVIDIA driver installed.

If you are often rebuilding PyTorch and need a proper CUDA setup, I would recommend trying out the pytorch docker containers (and rebuilding inside them), so that you won’t have to figure out where to install which package etc.

Hold up. So you are saying it doesn’t even matter what the default CUDA version (nvcc --version) on my system is? I always thought that when you install PyTorch (from the “get started” page), you had to select the CUDA version that is currently installed/the default on your system. But from what you are saying, it seems that doesn’t even matter and that PyTorch installs the correct CUDA libs in the environment? I never realized this! What does that mean for other scripts in the environment, though? Will they all use the CUDA libs that PyTorch loaded (does PyTorch set some environment-specific variables so the current env knows where those libs are?)?

That is correct, if you are using “pure” PyTorch. To build custom CUDA extensions in PyTorch, you would still need a local CUDA toolkit installation (since you are compiling the extension with nvcc), and I would recommend installing the matching CUDA version. Otherwise your locally compiled CUDA extension would use a different CUDA version than the PyTorch runtime, and you might run into errors (PyTorch should raise an error if a mismatching local install is detected).
However, if you are just using PyTorch without any extensions and are not rebuilding, you don’t need a local CUDA installation.
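
A quick way to see which runtime libs the binary shipped with (a small sketch, run from inside the environment):

import torch

# versions of the CUDA/cudnn runtime bundled with the installed PyTorch binary,
# independent of any CUDA toolkit installed system-wide
print(torch.version.cuda)               # e.g. '10.2'
print(torch.backends.cudnn.version())   # e.g. 7605
print(torch.cuda.is_available())        # True if the NVIDIA driver supports this runtime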

Other frameworks and libraries might either ship with their own runtime libs or might need a local installation. This would depend on the 3rd party library and how they decide to create the package.

That explains the large file sizes for the CUDA versions. I can’t believe that I hadn’t realized this before! So the only software requirement is that the installed graphics driver supports the CUDA version that you are requesting, right?

I remember that when cloning apex, the CUDA version number had to be identical (major and minor) to the version number PyTorch was built with. So in that scenario, I guess apex used the globally installed CUDA libs and threw a warning when that differed from the PT ones. What I do not understand, though, is why apex could then not build itself using the CUDA libs that are shipped with PyTorch? It seems that I am misunderstanding something here. Doesn’t PyTorch’s included CUDA include build tools? (I am not very familiar with all the subcomponents of CUDA.)

Yes, that is correct. You can find the NVIDIA driver requirements for different CUDA versions here.

Yes, that is also correct: apex was built as an extension, so you needed to compile it yourself (which is why the “native” amp in torch.cuda.amp can be used directly without rebuilding).

PyTorch ships only with the runtime libs, not with the compiler (nvcc), as this would blow up the size even more. I’m also not sure if you can even ship the compiler in conda binaries or pip wheels.
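
One way to see that distinction from Python (a sketch; CUDA_HOME points to a local toolkit with nvcc and is None when no toolkit is installed):

import torch
from torch.utils.cpp_extension import CUDA_HOME

print(torch.version.cuda)   # CUDA runtime version bundled with the PyTorch binary
print(CUDA_HOME)            # path to a local CUDA toolkit (needed for nvcc), or None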

Thanks, this is all very new and super helpful for me. Thanks for taking the time to explain it in detail!!

@BramVanroy btw, do you use nn.DataParallel() or nn.parallel.DistributedDataParallel?
I found that when using nn.DataParallel(), torch.cuda.amp.GradScaler() does not work, but with nn.parallel.DistributedDataParallel() it works well.
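
(For reference, the amp docs mention that nn.DataParallel runs the forward pass in side threads and that the autocast state is thread-local, so autocast has to be re-applied inside the model's forward; I haven't confirmed whether that is the issue here. A minimal sketch with a hypothetical model:)

import torch.nn as nn
from torch.cuda.amp import autocast

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    @autocast()          # re-enables autocast inside the threads spawned by nn.DataParallel
    def forward(self, x):
        return self.fc(x)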

I never use DataParallel so I cannot help you with that. I use the Distributed variant.