Quantization Aware Training - Tiny YOLOv3

Hi,

I’m trying to implement Quantization Aware Training (QAT) for my Tiny YOLOv3 model (I have mostly used ultralytics/yolov3 as the base for my code). This is what my model architecture looks like:

Model(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv(
      (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (11): ZeroPad2d(padding=[0, 1, 0, 1], value=0.0)
    (12): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
    (13): Conv(
      (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (14): Conv(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (15): Conv(
      (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (16): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (17): Upsample(scale_factor=2.0, mode=nearest)
    (18): Concat()
    (19): Conv(
      (conv): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.1, inplace=True)
    )
    (20): Detect(
      (m): ModuleList(
        (0): Conv2d(256, 27, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 27, kernel_size=(1, 1), stride=(1, 1))
      )
    )
  )
)

And this is the code snippet where I am trying to implement QAT (as part of my train.py script):

model.train()

# Set backend to qnnpack
torch.backends.quantized.engine = 'qnnpack'

# QAT: Attach global qconfig
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')

# QAT: Fuse modules
model_fused = torch.quantization.fuse_modules(model, [[model.model[2].conv, model.model[2].bn], 
                                                      [model.model[4].conv, model.model[4].bn], 
                                                      [model.model[6].conv, model.model[6].bn], 
                                                      [model.model[8].conv, model.model[8].bn], 
                                                      [model.model[10].conv, model.model[10].bn], 
                                                      [model.model[13].conv, model.model[13].bn], 
                                                      [model.model[14].conv, model.model[14].bn], 
                                                      [model.model[15].conv, model.model[15].bn], 
                                                      [model.model[16].conv, model.model[16].bn], 
                                                      [model.model[19].conv, model.model[19].bn]], inplace=True)

# QAT: Perform "fake quantization"
model_prepared = torch.quantization.prepare_qat(model_fused, inplace=True)

The issue I’m facing is that the fuse_modules() call raises the following error:

Traceback (most recent call last):
  File "train.py", line 606, in <module>
    train(hyp, opt, device, tb_writer, wandb)
  File "train.py", line 123, in train
    [model.model[19].conv, model.model[19].bn]], inplace=True)
  File "/azureml-envs/azureml_0ae001c63ee102296c480ee5afc65405/lib/python3.6/site-packages/torch/quantization/fuse_modules.py", line 146, in fuse_modules
    _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
  File "/azureml-envs/azureml_0ae001c63ee102296c480ee5afc65405/lib/python3.6/site-packages/torch/quantization/fuse_modules.py", line 74, in _fuse_modules
    mod_list.append(_get_module(model, item))
  File "/azureml-envs/azureml_0ae001c63ee102296c480ee5afc65405/lib/python3.6/site-packages/torch/quantization/fuse_modules.py", line 15, in _get_module
    tokens = submodule_key.split('.')
  File "/azureml-envs/azureml_0ae001c63ee102296c480ee5afc65405/lib/python3.6/site-packages/torch/nn/modules/module.py", line 948, in __getattr__
    type(self).__name__, name))
AttributeError: 'Conv2d' object has no attribute 'split'

I’ve followed the QAT tutorial in the PyTorch docs and can’t understand why this error occurs.

Hi,

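The second argument of

torch.quantization.fuse_modules(model, list)

must be a list of lists of module names (strings), not the modules themselves. You passed the Conv2d and BatchNorm2d objects directly, and that causes the error: as the traceback shows, fuse_modules looks each entry up by calling .split('.') on it, and a Conv2d object has no split attribute. Change the second argument to the names of your layers as defined in the __init__ method of your model; nested submodules are addressed with dotted paths, so model.model[2].conv becomes 'model.2.conv'. A short example: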

If your Model is defined like this:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, scale):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        ...

You can fuse layers with the following:

model = torch.quantization.fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2']])
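For the model in the question, the same idea means passing strings like 'model.2.conv' and 'model.2.bn' instead of model.model[2].conv and model.model[2].bn. Rather than typing every name by hand, they can be collected with named_modules(). Below is a minimal sketch, assuming every block to fuse has a Conv2d child named conv and a BatchNorm2d child named bn, as in the printed architecture. find_conv_bn_pairs, Conv and TinyModel are hypothetical names; Conv and TinyModel are small stand-ins for the real ultralytics classes, not the actual code.

```python
import torch
import torch.nn as nn


def find_conv_bn_pairs(model: nn.Module) -> list:
    """Collect [[conv_name, bn_name], ...] as strings for fuse_modules.

    Assumes (hypothetically) that each fusable block exposes a Conv2d child
    named `conv` and a BatchNorm2d child named `bn`, as in the printed
    ultralytics-style architecture.
    """
    pairs = []
    for name, module in model.named_modules():
        conv = getattr(module, "conv", None)
        bn = getattr(module, "bn", None)
        if isinstance(conv, nn.Conv2d) and isinstance(bn, nn.BatchNorm2d):
            prefix = name + "." if name else ""  # the root module has the empty name
            pairs.append([prefix + "conv", prefix + "bn"])
    return pairs


class Conv(nn.Module):
    """Hypothetical stand-in for the ultralytics Conv block (conv -> bn -> act)."""

    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class TinyModel(nn.Module):
    """Hypothetical two-block model mirroring the `model.<index>.<child>` layout."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(Conv(3, 16), nn.MaxPool2d(2, 2), Conv(16, 32))

    def forward(self, x):
        return self.model(x)


model = TinyModel().eval()  # eval mode: BatchNorm gets folded into the conv weights
pairs = find_conv_bn_pairs(model)  # names such as "model.0.conv", not module objects
fused = torch.quantization.fuse_modules(model, pairs)  # inplace=False returns a fused copy
```

The act (LeakyReLU) child is left out of each pair, as in the question's own list. Unlike that hand-written list, the helper also picks up block 0. The sketch fuses in eval mode; the question fuses in train mode for QAT, and how that behaves depends on the PyTorch version: older releases such as the one in the traceback produce a trainable ConvBn2d, while recent releases expect torch.ao.quantization.fuse_modules_qat for training-mode fusion. It is worth checking which applies to your install.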

Hope this helps!


Have you completed Quantization Aware Training on Tiny YOLOv3? If so, could you share the code? Thank you very much!