Quantization-aware training multi-GPU support?

Does PyTorch support multi-GPU quantization-aware training?
This script, https://github.com/pytorch/vision/blob/master/references/classification/train_quantization.py#L73, appears to contain multi-GPU logic.

Hi @robotcator123,
Multi-GPU training is orthogonal to quantization-aware training. Code written with PyTorch's quantization-aware training modules will work whether you are using a single GPU or DataParallel on multiple GPUs. Hope this helps!
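
To make the ordering concrete, here is a minimal sketch of preparing a model for QAT first and only then wrapping it for multi-GPU use. `TinyNet` and `build_qat_model` are hypothetical names used for illustration (they stand in for the mobilenet_v2 model discussed in this thread), and the sketch uses DataParallel rather than the DistributedDataParallel path from the reference script. It falls back to a single device when fewer than two GPUs are visible.

```python
import torch
from torch import nn
import torch.quantization


class TinyNet(nn.Module):
    """Hypothetical toy model; a stand-in for mobilenet_v2."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 4)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu(self.conv(x)))
        x = self.fc(torch.flatten(x, 1))
        return self.dequant(x)


def build_qat_model(backend="qnnpack"):
    """Attach a QAT qconfig and insert fake-quant modules (hypothetical helper)."""
    model = TinyNet()
    model.train()  # QAT preparation is done on a model in training mode
    model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
    torch.quantization.prepare_qat(model, inplace=True)
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_qat_model().to(device)

# Wrap for multi-GPU only after the QAT modules are in place.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# One ordinary training step; nothing here is specific to quantization.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
inputs = torch.randn(4, 3, 16, 16, device=device)
targets = torch.randint(0, 4, (4,), device=device)
loss = nn.functional.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The training step itself is the same as for a float model; the only QAT-specific part is the preparation done before the wrapper is applied.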


Hi, @Mazhar_Shaikh

Actually, I trained MobileNet on a cluster with this command:
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --data-path=./imagenet_1k
and the code seems to work fine.

But when I change the training script to
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --data-path=./imagenet_1k

it raises the following error after printing a lot of unrecognizable data:
`Namespace(backend='qnnpack', batch_size=32, cache_dataset=False, data_path='~/test/imagenet_1k', device='cuda', dist_backend='nccl', dist_url='env://', distributed=True, epochs=90, eval_batch_size=128, gpu=0, lr=0.0001, lr_gamma=0.1, lr_step_size=30, model='mobilenet_v2', momentum=0.9, num_batch_norm_update_epochs=3, num_calibration_batches=32, num_observer_update_epochs=4, output_dir='.', post_training_quantize=False, print_freq=10, rank=0, resume='', start_epoch=0, test_only=False, weight_decay=0.0001, workers=16, world_size=8)
Loading data
Loading data
Loading training data
Took 0.27007627487182617
Loading validation data
Creating data loaders
Creating model mobilenet_v2
Traceback (most recent call last):
  File "train_quantization.py", line 258, in <module>
    main(args)
  File "train_quantization.py", line 77, in main
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
  File "xxx/.conda/envs/pytorch1.3/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 298, in __init__
    self.broadcast_bucket_size)
  File "xxx/.conda/envs/pytorch1.3/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 480, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
TypeError: _broadcast_coalesced(): incompatible function arguments. The following argument types are supported:
    1. (process_group: torch.distributed.ProcessGroup, tensors: List[at::Tensor], buffer_size: int) -> None

Invoked with: <torch.distributed.ProcessGroupNCCL object at 0x7f943f78dd18>, [tensor([[[[ 1.3185e-02, -4.3213e-03, 1.4823e-02],

subprocess.CalledProcessError: Command '['/xxxx/pytorch1.3/bin/python', '-u', 'train_quantization.py', '--data-path=./imagenet_1k']' returned non-zero exit status 1.`
Sorry for hiding some personal information.

It seems that broadcast cannot handle a None tensor. Does anybody know about this problem?

Hi @robotcator123, If you believe broadcast of None doesn’t work as expected, please open an issue against PyTorch with a minimal reproducible example.
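
One way to narrow this down before filing an issue is to check whether the prepared model's state contains anything that is not a tensor. The traceback shows DistributedDataParallel passing a list to `_broadcast_coalesced`, which accepts only tensors; the sketch below assumes (this is not confirmed in the thread) that the list is built from the wrapped module's `state_dict()`. `find_non_tensor_entries` is a hypothetical helper name, and the `observer.min_val` key is made up for illustration.

```python
import torch
from torch import nn


def find_non_tensor_entries(state):
    """Return (key, type name) pairs for values that are not torch.Tensor."""
    return [(key, type(value).__name__)
            for key, value in state.items()
            if not isinstance(value, torch.Tensor)]


# A plain float module: every state_dict value is a tensor.
clean = find_non_tensor_entries(nn.Linear(4, 2).state_dict())

# Hypothetical state in which a statistic is still stored as None.
suspect = find_non_tensor_entries(
    {"weight": torch.zeros(2, 4), "observer.min_val": None})

print(clean)    # []
print(suspect)  # [('observer.min_val', 'NoneType')]
```

Calling the helper on `model.state_dict()` right before the DistributedDataParallel wrapper is created would show whether any such entries exist in the failing setup.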

Hi @dskhudia, thank you for your response. Here are some reproduction steps. I may put together a more detailed reproducible example over the weekend.

1: Download the imagenet1k dataset.
2: Run pip install torchvision==0.5.0; this will upgrade torch to 1.4.0.
3: Run the script with the command
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quant.py --data-path=./imagenet_1k

The train_quant.py script is borrowed from torchvision reference code.

from __future__ import print_function
import datetime
import os
import time
import sys
import copy

import torch
import torch.utils.data
from torch import nn
import torchvision
import torch.quantization
import train_utils as utils
from train import train_one_epoch, evaluate, load_data

def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)

    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed "
                           "on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
                                                                   args.cache_dataset, args.distributed)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
        torch.quantization.prepare_qat(model, inplace=True)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay)

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=args.lr_step_size,
                                                       gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
        print(model.module)

    model.apply(torch.quantization.enable_observer)
    model.apply(torch.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print('Starting training for epoch', epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
                        args.print_freq)
        lr_scheduler.step()
        with torch.no_grad():
            if epoch >= args.num_observer_update_epochs:
                print('Disabling observer for subseq epochs, epoch = ', epoch)
                model.apply(torch.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print('Freezing BN for subseq epochs, epoch = ', epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print('Evaluate QAT model')

            evaluate(model, criterion, data_loader_test, device=device)
            quantized_eval_model = copy.deepcopy(model)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device('cpu'))
            torch.quantization.convert(quantized_eval_model, inplace=True)

            print('Evaluate Quantized model')
            evaluate(quantized_eval_model, criterion, data_loader_test,
                     device=torch.device('cpu'))

        model.train()

        print('Saving models after epoch ', epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path',
                        default='/datasets01/imagenet_full_size/061417/',
                        help='dataset')
    parser.add_argument('--model',
                        default='mobilenet_v2',
                        help='model')
    parser.add_argument('--backend',
                        default='qnnpack',
                        help='fbgemm or qnnpack')
    parser.add_argument('--device',
                        default='cuda',
                        help='device')

    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        help='batch size for calibration/training')
    parser.add_argument('--eval-batch-size', default=128, type=int,
                        help='batch size for evaluation')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--num-observer-update-epochs',
                        default=4, type=int, metavar='N',
                        help='number of total epochs to update observers')
    parser.add_argument('--num-batch-norm-update-epochs', default=3,
                        type=int, metavar='N',
                        help='number of total epochs to update batch norm stats')
    parser.add_argument('--num-calibration-batches',
                        default=32, type=int, metavar='N',
                        help='number of batches of training set for \
                              observer calibration ')

    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr',
                        default=0.0001, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum',
                        default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr-step-size', default=30, type=int,
                        help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float,
                        help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int,
                        help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')

    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. \
             It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )
    # distributed training parameters
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url',
                        default='env://',
                        help='url used to set up distributed training')

    args = parser.parse_args()

    return args

if __name__ == "__main__":
    args = parse_args()
    main(args)