Convert_syncbn_model causes gradient overflow with apex mixed precision

Hi,
I was trying to train my network using apex mixed precision. I’ve tried DenseNet and Resnet as backbones for a segmentation task using CityScapes. Unfortunately, when I try to synchronize the batch norm using convert_syncbn_model, the scale_loss ends up being zero after a few iterations because of gradient overflow. This does not happen if I remove the batch normalization.
The relevant snippet of my code is the following:

model.cuda(gpu)
model = apex.parallel.convert_syncbn_model(model)
optimizer = optim.Adam(model.parameters())
model, optimizer = apex.amp.initialize(model, optimizer)
#model = apex.parallel.convert_syncbn_model(model) #I also tried to put it here
net = DDP(model, delay_allreduce=True)

loss = Cross_entropy(y_pred, y_gt)
with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()

System:
OS: Ubuntu 16.04 and 18.04
Pytorch: tried with 1.4, 1.5 and 1.6
apex: 0.1

Did anyone experience the same?
Thanks

Mixed precision training utilities as well as synchronized batchnorm layers are now available in PyTorch directly, so you don’t need apex anymore.
We recommend using these native implementations now. :slight_smile:
Could you try them and see if you encounter any issues?
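
For reference, the native replacement for apex.parallel.convert_syncbn_model is torch.nn.SyncBatchNorm.convert_sync_batchnorm. Below is a minimal sketch of the conversion step only; the small Sequential model is a hypothetical stand-in for the real network, and the DDP lines are left as comments because they assume an initialized process group and a `gpu` index as in the original snippet.

```python
from torch import nn

# Toy stand-in for the real segmentation network (hypothetical layers).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Replace every BatchNorm*d layer in the model with torch.nn.SyncBatchNorm.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)

# In a real multi-GPU run, after torch.distributed.init_process_group(...),
# the converted model would then be moved to its device and wrapped:
# model = model.cuda(gpu)
# model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
```

The conversion is done before wrapping the model in DistributedDataParallel, which matches the order shown in the PyTorch documentation.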

I’ve replaced the code with the following cuda.amp code:

scaler = torch.cuda.amp.GradScaler()

for epoch in range(n_epochs):
    for i, (X_batch, y_gt) in enumerate(batches):
        X_batch = X_batch.cuda()
        y_gt = y_gt.cuda()
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            y_pred = model(X_batch)
            loss = Cross_entropy(y_pred, y_gt)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

It is training stably so far, but it is really slow. I’m using a DenseNet121 backbone and some convolutions to output 19 segmentation classes, with a batch size of 2 on 2 GTX 1080Ti GPUs. The input size is 512x512. When I don’t use cuda.amp, each iteration (forward and backward passes) takes ~2 seconds. With cuda.amp, each iteration takes ~6.6 seconds.

System:
OS: 18.04
Pytorch: 1.6 (1.6.0+cu101)
Cuda: 10.1 (cannot get cuda 10.2 because I don’t have superuser access)
GPU: 2 GTX 1080Ti


Your 1080Ti GPUs do not have TensorCores, so you shouldn’t expect a speedup from the computations. However, how did you measure the performance? Did you synchronize the code properly using torch.cuda.synchronize() before starting and stopping the timers?
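
CUDA kernels are launched asynchronously, so a timer that is not preceded by a synchronization can stop before the GPU work has finished. A minimal sketch of a timing helper follows; the function name, its parameters, and `train_step` are hypothetical, and the `sync` argument is where torch.cuda.synchronize would be passed.

```python
import time


def seconds_per_iteration(step_fn, n_iters=10, warmup=3, sync=None):
    """Return the average wall-clock seconds per call of step_fn.

    sync is called right before the timer starts and right before it stops.
    For GPU work, pass torch.cuda.synchronize so queued kernels have finished.
    """
    if sync is None:
        sync = lambda: None  # CPU-only fallback: nothing to wait for
    for _ in range(warmup):  # warm-up iterations are not timed
        step_fn()
    sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    sync()
    return (time.perf_counter() - start) / n_iters


# Hypothetical usage, where train_step() runs forward, backward and optimizer step:
# print(seconds_per_iteration(train_step, sync=torch.cuda.synchronize))
```

The warm-up iterations keep one-time costs (for example cudnn autotuning) out of the average.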


I measure the performance following the post below. Instead of the while loop, I run the forward pass, backward pass, optimizer step, and zeroing of the parameter gradients, and then print the time each iteration takes.

I can confirm that the training did not face any problems yet, apart from the slow training :confused:.
Do you think the CUDA version might be causing this? Is there a verbose option I can enable for debugging, or a way to print a message like apex does when there is a gradient overflow and the loss scale is being adjusted?

Thanks for the help :slight_smile:.
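
On the logging question: torch.cuda.amp.GradScaler exposes get_scale(), so one way to get apex-style overflow messages is to compare the scale before and after update(). This is only a sketch; `step_and_report` and its `log` parameter are hypothetical names, and `scaler` is assumed to behave like GradScaler (step, update, get_scale).

```python
def step_and_report(scaler, optimizer, log=print):
    """Run scaler.step/update and report when the loss scale was reduced.

    Returns True if the scale went down, which indicates that inf/NaN
    gradients were found and the optimizer step was skipped.
    """
    scale_before = scaler.get_scale()
    scaler.step(optimizer)  # skipped internally if inf/NaN gradients were found
    scaler.update()         # lowers the scale after a skipped step
    scale_after = scaler.get_scale()
    reduced = scale_after < scale_before
    if reduced:
        log(f"Gradient overflow: loss scale {scale_before} -> {scale_after}")
    return reduced


# Hypothetical usage inside the training loop, replacing the two scaler calls:
# scaler.scale(loss).backward()
# step_and_report(scaler, optimizer)
```

Reading the scale forces a CPU-GPU synchronization, so this is better suited to debugging runs than to timing measurements.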

Thanks for the information.
I assume the slowdown can be observed on a single device and would thus be unrelated to SyncBatchnorm.
Could you install the latest stable PyTorch release (or the nightly binaries) with CUDA10.2 and cudnn7.6.5.32?
Note that the binaries ship with their own CUDA and cudnn runtimes. Your local installations won’t be used, so you don’t need to update the local CUDA version.

Thanks,
I have downloaded the latest PyTorch with CUDA 10.2 using (pip install torch torchvision). To reduce possible errors, I also tested it on a machine with CUDA 10.2 installed. Unfortunately, the training remains slow when using torch.cuda.amp.

System:
Ubuntu 18.04
Cuda: 10.2
Pytorch: 1.6
GPU: 2 GTX 1080Ti

Thanks for the check.
Are you seeing the slowdown using the torchvision.models.densenet121, or are you changing it somehow?
I would like to reproduce it with cudnn7 as well as cudnn8 and check why amp is slowing down the model.


Hi, I modified the original PyTorch densenet121 to extract features at different scales. Here is the modified densenet; it still can load the weights provided by PyTorch. I cannot send you our exact decoder, but I guess any SegNet-like architecture for pixel-wise segmentation should show the same behavior.

To check if the error was caused by my code, I used this repo, which is a deeplab v3 implementation with pretrained weights for cityscapes. I tested torch.cuda.amp, O1 from apex.amp, and no mixed precision (the last one was tested with half of the size of the input image used by the mixed precision tests).

I’ve run some epochs and found that torch.cuda.amp is about twice as slow as both apex.amp and no mixed precision. In numbers, 1 iteration (forward and backward) using torch.cuda.amp takes ~2.50 seconds, while the other two take ~1.20 seconds.

Configuration:
batch size: 2 per GPU
Image size: 512x512 (mixed precision), 256x512 (no mixed precision)
Batch sync: Yes

Dataset: Cityscapes semantic segmentation
Code: https://github.com/nyoki-mtl/pytorch-segmentation
optimizer: adam
loss: cross entropy

System:
Ubuntu 18.04
Cuda: 10.2
CudNN: 8.0.3
Pytorch: 1.6
GPU: 2 GTX 1080Ti

If you need more details please let me know. Thanks for the help :slight_smile:.


p.s.
Just to add more information about Apex: I’m still getting gradient overflow, with scale_loss being reduced to zero after a few epochs, when using apex.parallel.convert_syncbn_model().

Thanks @ptrblck for your help. I managed to get access to a pair of RTX2080Ti. On these, the speed of torch.cuda.amp is similar to apex. My guess is that torch.cuda.amp is optimized to use TensorCores to boost the speed, and since GTX GPUs do not have them, the computation is slower. Could you confirm this?

The model execution shouldn’t be slower on Pascal GPUs, but you also shouldn’t expect a significant speedup. The model execution should benefit from the reduced memory transfers, which could show a speedup even on older GPUs.

Are you seeing the slowdown on the 1080Ti even with torch.backends.cudnn.benchmark=True?
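
On the Tensor Core question, a quick check is the device's compute capability, which torch.cuda.get_device_capability() returns as a (major, minor) tuple. The sketch below is a rough heuristic, not an official check: the helper name is hypothetical, and it assumes that capability 7.0 (Volta) or newer implies Tensor Cores, which is not true for every card (some Turing GTX models report 7.5 without them).

```python
def likely_has_tensor_cores(capability):
    """Heuristic: Tensor Cores first appeared with compute capability 7.0.

    capability is a (major, minor) tuple, as returned by
    torch.cuda.get_device_capability().
    """
    major, _minor = capability
    return major >= 7


# GTX 1080Ti (Pascal) reports (6, 1); RTX 2080Ti (Turing) reports (7, 5).
print(likely_has_tensor_cores((6, 1)))
print(likely_has_tensor_cores((7, 5)))

# Hypothetical usage on a machine with a GPU:
# import torch
# print(likely_has_tensor_cores(torch.cuda.get_device_capability(0)))
```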

Yes, even with torch.backends.cudnn.benchmark=True placed at the beginning of the script, it is slow on the GTX1080Ti. On the other hand, I have now been training for around 3 days using torch.cuda.amp on a pair of RTX2080Ti’s without any problems. This further supports my theory about how torch.cuda.amp is optimized.