Model training seems to have no effect when using distributed training with automatic mixed precision

Hi everyone. I'm using the NVIDIA Apex package to speed up training of my model with automatic mixed precision. However, even though the loss keeps dropping, the model's inference results do not improve. My training code is as follows:

import os
import argparse
import time
import tqdm
import cv2
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from config.config import GetConfig, COCOSourceConfig, TrainingOpt
from data.mydataset import MyDataset
from torch.utils.data import DataLoader
from models.posenet import Network
from models.loss_model import MultiTaskLoss
import warnings

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")


warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser(description='PoseNet Training')
parser.add_argument('--resume', '-r', action='store_true', default=True, help='resume from checkpoint')
parser.add_argument('--checkpoint_path', '-p',  default='link2checkpoints_distributed', help='save path')
parser.add_argument('--max_grad_norm', default=5, type=float,
                    help=("If the norm of the gradient vector exceeds this, "
                          "re-normalize it to have the norm equal to max_grad_norm"))
# FOR DISTRIBUTED:  Parse for the local_rank argument, which will be supplied automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--opt-level', type=str, default='O1')
parser.add_argument('--sync_bn',  action='store_true', default=False, help='enabling apex sync BN.')  
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--print-freq', '-f', default=10, type=int, metavar='N', help='print frequency (default: 10)')

torch.backends.cudnn.benchmark = True  
use_cuda = torch.cuda.is_available()

args = parser.parse_args()

checkpoint_path = args.checkpoint_path
opt = TrainingOpt()
config = GetConfig(opt.config_name)
soureconfig = COCOSourceConfig(opt.hdf5_train_data)
train_data = MyDataset(config, soureconfig, shuffle=False, augment=True)  # shuffle in data loader

soureconfig_val = COCOSourceConfig(opt.hdf5_val_data)
val_data = MyDataset(config, soureconfig_val, shuffle=False, augment=True)  # shuffle in data loader


best_loss = float('inf')
start_epoch = 0  
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

args.gpu = 0
args.world_size = 1

# FOR DISTRIBUTED:  If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
if args.distributed:
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()  # number of processes in distributed training

assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

posenet = Network(opt, config, dist=True, bn=False)
# Actual working batch size on multi-GPUs is 4 times bigger than that on one GPU
# fixme: add up momentum if the batch grows?
optimizer = optim.SGD(posenet.parameters(), lr=opt.learning_rate * args.world_size, momentum=0.9, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2, last_epoch=-1)


if args.sync_bn:
    #  This should be done before model = DDP(model, delay_allreduce=True),
    #  because DDP needs to see the finalized model parameters
    # We rely on torch distributed for synchronization between processes. Only DDP supports apex sync_bn for now.
    import apex
    print("Using apex synced BN.")
    posenet = apex.parallel.convert_syncbn_model(posenet)

posenet.cuda()

# Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
# for convenient interoperation with argparse.
# For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
# This must be done AFTER the call to amp.initialize.
model, optimizer = amp.initialize(posenet, optimizer,
                                  opt_level=args.opt_level,
                                  keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                  loss_scale=args.loss_scale)  # Dynamic loss scaling is used by default.
# delay_allreduce delays all communication to the end of the backward pass.


if args.distributed:
    # By default, apex.parallel.DistributedDataParallel overlaps communication with computation in the backward pass.
    # model = DDP(model)
    # delay_allreduce delays all communication to the end of the backward pass.
    model = DDP(model, delay_allreduce=True)



train_sampler = None
val_sampler = None
# Restricts data loading to a subset of the dataset exclusive to the current process
# Create DistributedSampler to handle distributing the dataset across nodes when training 
# This can only be called after distributed.init_process_group is called

if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, shuffle=(train_sampler is None),
                                           num_workers=16, pin_memory=True, sampler=train_sampler, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=opt.batch_size, shuffle=False,
                                         num_workers=4, pin_memory=True, sampler=val_sampler, drop_last=True)

for param in model.parameters():
    if param.requires_grad:
        print('Parameters of network: Autograd')
        break


#  Update the learning rate for start_epoch times
for i in range(start_epoch):
    scheduler.step()


def train(epoch):
    print('\n ############################# Train phase, Epoch: {} #############################'.format(epoch))
    posenet.train()

    if args.distributed:
        train_sampler.set_epoch(epoch)
    # train_loss = 0
    scheduler.step()
    print('\nLearning rate at this epoch is: %0.9f\n' % optimizer.param_groups[0]['lr'])  # scheduler.get_lr()[0]

    batch_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    for batch_idx, target_tuple in enumerate(train_loader):
        # images.requires_grad_()
        # loc_targets.requires_grad_()
        # conf_targets.requires_grad_()
        if use_cuda:
      
            target_tuple = [target_tensor.cuda(non_blocking=True) for target_tensor in target_tuple]

        # target tensor shape: [8,512,512,3], [8, 1, 128,128], [8,43,128,128], [8,36,128,128], [8,36,128,128]
        images, mask_misses, heatmaps = target_tuple  # , offsets, mask_offsets
        # images = Variable(images)
        # loc_targets = Variable(loc_targets)
        # conf_targets = Variable(conf_targets)

        loss = model(images, target_tuple[1:])
        optimizer.zero_grad()  # zero the gradient buff

        if loss.item() > 1e6:
            print("\nLoss is abnormal, drop this batch !")
            loss.zero_()
            continue

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()  

     
        if batch_idx % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
         
            if args.distributed:
                # We manually reduce and average the metrics across processes. In-place reduce tensor.
                reduced_loss = reduce_tensor(loss.data)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), images.size(0))  # update needs average and number
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:  # Print them in the Process 0
                print('==================> Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f}) <================ \t'.format(
                        epoch, batch_idx, len(train_loader),
                        args.world_size * opt.batch_size / batch_time.val,
                        args.world_size * opt.batch_size / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    global best_loss

    # train_loss /= (len(train_loader))  # Each GPU process can only see 1/(world_size) training samples per epoch

    if args.local_rank == 0:
        # Write the log file each epoch.
        os.makedirs(checkpoint_path, exist_ok=True)
        logger = open(os.path.join('./' + checkpoint_path, 'log'), 'a+')
        logger.write('\nEpoch {}\ttrain_loss: {}'.format(epoch, losses.avg))  
        logger.flush()
        logger.close()

        if losses.avg < best_loss:
            # Update the best_loss if the average loss drops
            best_loss = losses.avg
            print('Saving model checkpoint...')
            state = {
                # Under DDP, save model.module.state_dict() so the keys don't get the "module." prefix
                'weights': (model.module if args.distributed else model).state_dict(),
                'optimizer_weight': optimizer.state_dict(),
                'train_loss': losses.avg,
                'epoch': epoch
            }
            torch.save(state, './' + checkpoint_path + '/PoseNet_' + str(epoch) + '_epoch.pth')


def test(epoch):
    print('\n ############################# Test phase, Epoch: {} #############################'.format(epoch))
    posenet.eval()
  
    if args.distributed:
        val_sampler.set_epoch(epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    for batch_idx, target_tuple in enumerate(val_loader):
        # images.requires_grad_()
        # loc_targets.requires_grad_()
        # conf_targets.requires_grad_()
        if use_cuda:
    
            target_tuple = [target_tensor.cuda(non_blocking=True) for target_tensor in target_tuple]

        # target tensor shape: [8,512,512,3], [8, 1, 128,128], [8,43,128,128], [8,36,128,128], [8,36,128,128]
        images, mask_misses, heatmaps = target_tuple  # , offsets, mask_offsets

        with torch.no_grad():
            _, loss = model(images, target_tuple[1:])

        if args.distributed:
            # We manually reduce and average the metrics across processes. In-place reduce tensor.
            reduced_loss = reduce_tensor(loss.data)
        else:
            reduced_loss = loss.data

        # to_python_float incurs a host<->device sync
        losses.update(to_python_float(reduced_loss), images.size(0))  # update needs average and number
        torch.cuda.synchronize() 
        batch_time.update((time.time() - end))
        end = time.time()

        if args.local_rank == 0:  # Print them in the Process 0
            print('==================>Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                   batch_idx, len(val_loader),
                   args.world_size * opt.batch_size / batch_time.val,
                   args.world_size * opt.batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses))

    if args.local_rank == 0:  # Print them in the Process 0
        # Write the log file each epoch.
        os.makedirs(checkpoint_path, exist_ok=True)
        logger = open(os.path.join('./' + checkpoint_path, 'log'), 'a+')
        logger.write('\tval_loss: {}'.format(losses.avg)) 
        logger.flush()
        logger.close()


def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30

    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    """Warmup"""
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)  # len_epoch=len(train_loader)

    # if(args.local_rank == 0):
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def reduce_tensor(tensor):
    # Reduces the tensor data across all machines
    # If we print the tensor, we can get:
    # tensor(334.4330, device='cuda:1') *********************, here is cuda:  cuda:1
    # tensor(359.1895, device='cuda:3') *********************, here is cuda:  cuda:3
    # tensor(263.3543, device='cuda:2') *********************, here is cuda:  cuda:2
    # tensor(340.1970, device='cuda:0') *********************, here is cuda:  cuda:0
    rt = tensor.clone()  # all_reduce operates in place, so reduce a copy
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt


if __name__ == '__main__':
    for epoch in range(start_epoch, start_epoch + 80):
        train(epoch)
        test(epoch)



To be more specific, I followed the ImageNet example in NVIDIA Apex. I compute the loss inside my Network module, as follows:

class Network(torch.nn.Module):
    """
    Wrap the network module as well as the loss module on all GPUs to balance the computation among GPUs.
    """
    def __init__(self, opt, config, bn=False, dist=False):
        super(Network, self).__init__()
        self.posenet = PoseNet(opt.nstack, opt.hourglass_inp_dim, config.num_layers, bn=bn)
        # With train_parallel we use the parallel loss; with train_distributed we use the
        # single-process loss, because each process on these 4 GPUs is independent
        self.criterion = MultiTaskLoss(opt, config) if dist else MultiTaskLossParallel(opt, config)

    def forward(self, inp_imgs, target_tuple):
        # Batch will be divided and Parallel Model will call this forward on every GPU
        output_tuple = self.posenet(inp_imgs)
        loss = self.criterion(output_tuple, target_tuple)

        if not self.training:
            # output will be concatenated  along batch channel automatically after the parallel model return
            return output_tuple, loss
        else:
            # output will be concatenated  along batch channel automatically after the parallel model return
            return loss

The training loss seems normal:

Epoch 0 train_loss: 589.6713480631511 val_loss: 536.4533081054688
Epoch 1 train_loss: 446.2322041829427 val_loss: 440.89935302734375
Epoch 2 train_loss: 436.07487325032554 val_loss: 433.20953369140625
Epoch 3 train_loss: 433.3325126139323 val_loss: 396.94744873046875
Epoch 4 train_loss: 425.1072373453776 val_loss: 406.3310546875
Epoch 5 train_loss: 418.57773783365883 val_loss: 392.5045166015625
Epoch 6 train_loss: 409.60796936035155 val_loss: 419.2001037597656
Epoch 7 train_loss: 410.79097737630207 val_loss: 409.8291320800781
Epoch 8 train_loss: 404.4842706298828 val_loss: 407.05352783203125
Epoch 9 train_loss: 399.4785394287109 val_loss: 388.7215881347656
Epoch 10 train_loss: 389.387607421875 val_loss: 379.6018981933594
Epoch 11 train_loss: 386.5943516031901 val_loss: 397.2137451171875
Epoch 12 train_loss: 382.25890686035154 val_loss: 376.7177734375
Epoch 13 train_loss: 387.2037613932292 val_loss: 360.4934387207031
Epoch 14 train_loss: 379.99100199381513 val_loss: 377.1543884277344
Epoch 15 train_loss: 381.0046073404948 val_loss: 378.36041259765625
Epoch 16 train_loss: 378.6185076904297 val_loss: 365.29205322265625
Epoch 17 train_loss: 380.5766967773437 val_loss: 364.39569091796875
Epoch 18 train_loss: 382.2865834554037 val_loss: 368.50152587890625

But the model does not seem to be trained well: the prediction results refuse to get better (they are actually bad).
I have struggled with this problem for a while. If I use neither distributed training nor Apex automatic mixed precision, and only wrap my Network module with torch.nn.parallel.DataParallel, everything works fine and the predictions are good.

Numerical issues are notoriously hard to debug.

Can you isolate the issue to either distributed or mixed precision?
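One way to run that isolation, sketched under assumptions: the script in the question already exposes `--opt-level`, and Apex's `O0` level is pure FP32 training, so crossing {single process, DDP} with {O0, O1} separates mixed precision from distributed training. The script name `train_distributed.py` and the GPU count are hypothetical placeholders.

```python
import itertools
import shlex

SCRIPT = "train_distributed.py"  # hypothetical name; substitute the real training script
NUM_GPUS = 4  # assumption: 4 GPUs, as in the question


def build_commands(script=SCRIPT, num_gpus=NUM_GPUS):
    """Return one launch command per (distributed, opt_level) combination.

    O0 is Apex's pure-FP32 level and O1 is the mixed-precision level used in
    the question. If only the (True, "O1") run misbehaves, the problem lies in
    the combination rather than in either feature alone.
    """
    commands = {}
    for distributed, opt_level in itertools.product((False, True), ("O0", "O1")):
        if distributed:
            # torch.distributed.launch sets WORLD_SIZE and passes --local_rank.
            prefix = ["python", "-m", "torch.distributed.launch",
                      f"--nproc_per_node={num_gpus}", script]
        else:
            # Single process: WORLD_SIZE is unset, so args.distributed stays False.
            prefix = ["python", script]
        commands[(distributed, opt_level)] = prefix + ["--opt-level", opt_level]
    return commands


if __name__ == "__main__":
    for (distributed, opt_level), argv in build_commands().items():
        label = "DDP" if distributed else "single"
        print(f"{label:6} {opt_level}: {shlex.join(argv)}")
```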

@mcarilli Any ideas?

This may well be an Apex bug. About a week ago, for a few days the combination of dynamic loss scaling + Apex DDP was broken in Apex master. I fixed it in https://github.com/NVIDIA/apex/commit/8437d29505fcc7fad28183395abd89a09a17efe6, so maybe a fresh clone + reinstall of Apex will resolve the issue. Be sure to clean the old install before rebuilding:
pip uninstall apex
cd apex_repo_dir
rm -rf build          # if present
rm -rf apex.egg-info  # if present
git pull
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
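After rebuilding with the commands above, a quick sanity check is whether the compiled extensions are importable at all. This sketch assumes that `--cpp_ext` builds a module named `apex_C` and `--cuda_ext` builds `amp_C`, as in Apex's setup.py of that period; treat those module names as an assumption to verify against your own checkout.

```python
import importlib.util

# Assumed module names produced by Apex's build flags (verify in setup.py).
EXPECTED_MODULES = {
    "apex": "Python package",
    "apex_C": "--cpp_ext",
    "amp_C": "--cuda_ext",
}


def check_apex_build(modules=EXPECTED_MODULES):
    """Map each top-level module name to True if it can be found on sys.path."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_apex_build().items():
        status = "found" if found else "MISSING"
        print(f"{name:8} ({EXPECTED_MODULES[name]}): {status}")
```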


Thank you for your reply. The problem has not been solved yet.

If I remove gradient clipping from the training step, the gradients explode after some batches; training looks okay before the explosion. No normalization operation (such as batch norm) is used in my model. All input and ground-truth tensors are normalized to [0, 1]. I use an L2 loss and weight decay. I have no idea which detail I should concentrate on.

Did you try a fresh clone and install of Apex?

Gradient clipping does require special treatment for compatibility with all opt_levels: https://nvidia.github.io/apex/advanced.html#gradient-clipping
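The Apex page linked above clips the FP32 master gradients after the `amp.scale_loss` block, i.e. `torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)`, rather than `model.parameters()`. For comparison, here is a sketch of the same ordering with PyTorch's native `torch.cuda.amp.GradScaler`, a swapped-in alternative to Apex rather than what this thread uses: gradients are unscaled before clipping, so the threshold applies to true gradient magnitudes. The tiny linear model and random data are placeholders; AMP is disabled when no GPU is present, so the sketch also runs on CPU.

```python
import torch
import torch.nn as nn

# AMP only makes sense on a GPU; on CPU the scaler and autocast are disabled,
# so the same code runs as plain FP32 training.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = nn.Linear(16, 4).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
max_grad_norm = 5.0  # same default as --max_grad_norm in the question


def train_step(inputs, targets):
    """One step; returns (loss, gradient norm measured before clipping)."""
    optimizer.zero_grad()
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = nn.functional.mse_loss(model(inputs), targets)
    # Backward on the scaled loss, so the gradients are scaled too.
    scaler.scale(loss).backward()
    # Unscale in place first; clipping scaled gradients would compare the
    # threshold against values multiplied by the loss scale.
    scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # step() skips the update if the unscaled gradients contain inf/NaN;
    # update() then adjusts the loss scale for the next iteration.
    scaler.step(optimizer)
    scaler.update()
    return loss.item(), total_norm.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 16, device=device)
    y = torch.full((8, 4), 10.0, device=device)
    loss, norm = train_step(x, y)
    print(f"loss={loss:.3f} grad norm before clipping={norm:.3f}")
```

Logging the value returned by `clip_grad_norm_` each step also shows how close training is to the explosion described above.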

Yes, I followed your instructions and reinstalled Apex. My problem is strange. I printed the abnormal values during distributed training.
First, my model predicts feature maps at several scales (it is a cascaded CNN, and each stage/stack has 5 scale outputs).

For some batches this looks normal, and the element-wise output should be in the range [0, 1] (Gaussian heatmap regression):

heatmap L2 loss per stack.........   [ 69848.01   59730.246 223546.12   60869.35 ]
 heatmap L2 loss per stack.........   [15058.608 13271.5   13770.041 13684.25 ]
 heatmap L2 loss per stack.........   [3515.2559 3062.7026 2899.563  2879.3105]
 heatmap L2 loss per stack.........   [ 84485.94  76283.11 219553.47  77723.48]
 heatmap L2 loss per stack.........   [ 70769.086  63346.633 209632.16   64268.496]
 heatmap L2 loss per stack.........   [18312.457 17451.66  17986.875 17935.975]
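For scale, a quick NumPy check of the float16 range (an illustration, not a diagnosis of this particular model): the largest finite float16 value is 65504, so several of the per-stack losses above, such as 223546.12, could not be represented if they were ever held in float16, and a single squared error of 256 already overflows.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0


def fits_in_fp16(value):
    """True if `value` stays finite after conversion to float16."""
    with np.errstate(over="ignore"):
        return bool(np.isfinite(np.float16(value)))


if __name__ == "__main__":
    # Per-stack losses copied from the log above.
    for v in [69848.01, 59730.246, 223546.12, 60869.35, 3515.2559]:
        print(f"{v:>12.3f} fits in float16: {fits_in_fp16(v)}")

    # A single squared error of 256 ** 2 = 65536 already exceeds FP16_MAX.
    print("256**2 fits:", fits_in_fp16(256.0 ** 2))

    # Between 128 and 256, adjacent float16 values are 0.125 apart, matching the
    # granularity of the 223.1250 / 222.7500 predictions in the "Dangerous!" lines.
    print("float16 spacing at 223:", float(np.spacing(np.float16(223.125))))
```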

However, the loss suddenly becomes abnormal: output elements become very large (e.g. 223), then grow rapidly, resulting in gradient explosion.

Dangerous! Check pred, gt, mask_miss: ======>  tensor(223.1250, device='cuda:2', dtype=torch.float16, grad_fn=<MaxBackward1>) tensor(1., device='cuda:2') tensor(1., device='cuda:2')
 heatmap L2 loss per stack.........   [0. 0. 0. 0.]
Dangerous! Check pred, gt, mask_miss: ======>  tensor(223.1250, device='cuda:1', dtype=torch.float16, grad_fn=<MaxBackward1>) tensor(1., device='cuda:1') tensor(1., device='cuda:1')
Dangerous! Check pred, gt, mask_miss: ======>  tensor(222.7500, device='cuda:3', dtype=torch.float16, grad_fn=<MaxBackward1>) tensor(1., device='cuda:3') tensor(1., device='cuda:3')
Dangerous! Check pred, gt, mask_miss: ======>  tensor(222.7500, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>) tensor(1., device='cuda:0') tensor(1., device='cuda:0')

Update: the problem has been solved. I added a clamp to the loss value and changed the weights of the multi-scale losses. It seems that automatic loss scaling in Apex is not perfect yet.
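The exact change was not posted, so the following is only a hypothetical sketch of one reading of the fix: clamp each per-scale loss, then combine the scales with explicit weights. The function name `combine_multiscale_losses`, the weights, and the clamp bound are illustrative assumptions, not the author's values. One caveat: `torch.clamp` passes zero gradient for inputs beyond the bound, so a clamped term contributes nothing to that step rather than a bounded gradient.

```python
import torch


def combine_multiscale_losses(per_scale_losses, weights, max_loss):
    """Hypothetical helper: clamp each scale's loss, then take a weighted sum.

    per_scale_losses: 1-D tensor, one scalar loss per output scale.
    weights: sequence of floats, one per scale (illustrative values).
    max_loss: upper bound applied to each per-scale loss before weighting.
    """
    w = torch.as_tensor(weights, dtype=per_scale_losses.dtype,
                        device=per_scale_losses.device)
    clamped = torch.clamp(per_scale_losses, max=max_loss)
    return (w * clamped).sum()


if __name__ == "__main__":
    # Placeholder losses for 5 scales; the third one is abnormally large.
    losses = torch.tensor([3515.0, 3062.0, 223546.0, 2879.0, 1500.0],
                          requires_grad=True)
    weights = [1.0, 1.0, 0.5, 0.5, 0.25]  # assumed weighting, not the author's
    total = combine_multiscale_losses(losses, weights, max_loss=10000.0)
    total.backward()
    print("total:", total.item())
    # The clamped entry gets zero gradient; the others get their weight.
    print("grad:", losses.grad.tolist())
```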

Could you explain your code in more detail? I have the same issue.