Questions about QAT

Hello!

I am trying to train MobileNetV3 with Lite Reduced ASPP (LR-ASPP) for semantic segmentation using quantization-aware training (QAT), but for some reason it does not train at all: the model output looks like random noise.
So I have a couple of questions.

  1. My activations are nn.ReLU6, nn.Sigmoid, nn.Hardsigmoid and nn.Hardswish. I tried both approaches, wrapping them with QuantWrapper and replacing them with ReLU, but neither helped. What is the correct way to handle them?
  2. I’ve replaced functional.interpolate with nn.UpsamplingBilinear2d. It also didn’t help. Is this replacement actually needed, or can I keep the functional version?
  3. I’ve replaced all add and mul operations with separate torch.nn.quantized.FloatFunctional() instances (a minimal sketch of my current wrapping pattern follows this list). That didn’t work either.
  4. With QAT, does the model usually need to train longer, or should it take roughly the same time?
  5. Am I right that it is possible to train a model with QAT on a GPU?
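
For reference, here is a minimal sketch of the wrapping pattern I am using right now (a simplified squeeze-and-excite block; names are illustrative and fusion happens separately before prepare_qat):

import torch
import torch.nn as nn
import torch.quantization as tq

class QuantizableSqueezeAndExcite(nn.Module):
    """Simplified sketch of my current pattern, not the full model."""

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        self._avg_pool = nn.AdaptiveAvgPool2d(1)
        self._fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            # Hardsigmoid runs in float between DeQuantStub and QuantStub,
            # since I was not sure it is supported as a quantized op.
            nn.Sequential(tq.DeQuantStub(), nn.Hardsigmoid(), tq.QuantStub()),
        )
        # elementwise multiply goes through FloatFunctional
        self._mul = nn.quantized.FloatFunctional()

    def forward(self, x):
        scale = self._avg_pool(x).flatten(1)
        scale = self._fc(scale).unsqueeze(-1).unsqueeze(-1)
        return self._mul.mul(x, scale)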

Here is the print of my model:

QuantizableMobileNetV3SmallLRASPP(
  (_quant): QuantStub()
  (_encoder): QuantizableMobileNetV3(
    (_layers): ModuleList(
      (0): Sequential(
        (0): ConvBn2d(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): Identity()
        (2): Hardswish()
      )
      (1): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): ReLU6()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): ReLU6()
        )
        (_sqse): QuantizableSqueezeAndExcite(
          (_avg_pool): AdaptiveAvgPool2d(output_size=1)
          (_fc): Sequential(
            (0): LinearReLU(
              (0): Linear(in_features=16, out_features=4, bias=True)
              (1): ReLU()
            )
            (1): Identity()
            (2): Linear(in_features=4, out_features=16, bias=True)
            (3): Sequential(
              (0): DeQuantStub()
              (1): Hardsigmoid()
              (2): QuantStub()
            )
          )
          (_mul): FloatFunctional(
            (activation_post_process): Identity()
          )
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (2): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(16, 72, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): ReLU6()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=72, bias=False)
            (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): ReLU6()
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(72, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (3): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(24, 88, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(88, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): ReLU6()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(88, 88, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=88, bias=False)
            (1): BatchNorm2d(88, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): ReLU6()
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(88, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (4): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(96, 96, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=96, bias=False)
            (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_sqse): QuantizableSqueezeAndExcite(
          (_avg_pool): AdaptiveAvgPool2d(output_size=1)
          (_fc): Sequential(
            (0): LinearReLU(
              (0): Linear(in_features=96, out_features=24, bias=True)
              (1): ReLU()
            )
            (1): Identity()
            (2): Linear(in_features=24, out_features=96, bias=True)
            (3): Sequential(
              (0): DeQuantStub()
              (1): Hardsigmoid()
              (2): QuantStub()
            )
          )
          (_mul): FloatFunctional(
            (activation_post_process): Identity()
          )
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(96, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (5): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
            (1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_sqse): QuantizableSqueezeAndExcite(
          (_avg_pool): AdaptiveAvgPool2d(output_size=1)
          (_fc): Sequential(
            (0): LinearReLU(
              (0): Linear(in_features=240, out_features=60, bias=True)
              (1): ReLU()
            )
            (1): Identity()
            (2): Linear(in_features=60, out_features=240, bias=True)
            (3): Sequential(
              (0): DeQuantStub()
              (1): Hardsigmoid()
              (2): QuantStub()
            )
          )
          (_mul): FloatFunctional(
            (activation_post_process): Identity()
          )
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (6): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=240, bias=False)
            (1): BatchNorm2d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_sqse): QuantizableSqueezeAndExcite(
          (_avg_pool): AdaptiveAvgPool2d(output_size=1)
          (_fc): Sequential(
            (0): LinearReLU(
              (0): Linear(in_features=240, out_features=60, bias=True)
              (1): ReLU()
            )
            (1): Identity()
            (2): Linear(in_features=60, out_features=240, bias=True)
            (3): Sequential(
              (0): DeQuantStub()
              (1): Hardsigmoid()
              (2): QuantStub()
            )
          )
          (_mul): FloatFunctional(
            (activation_post_process): Identity()
          )
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
      (7): QuantizableInvertedResidual(
        (_expansion): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(40, 120, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_conv): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(120, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=120, bias=False)
            (1): BatchNorm2d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
          (2): Hardswish()
        )
        (_sqse): QuantizableSqueezeAndExcite(
          (_avg_pool): AdaptiveAvgPool2d(output_size=1)
          (_fc): Sequential(
            (0): LinearReLU(
              (0): Linear(in_features=120, out_features=30, bias=True)
              (1): ReLU()
            )
            (1): Identity()
            (2): Linear(in_features=30, out_features=120, bias=True)
            (3): Sequential(
              (0): DeQuantStub()
              (1): Hardsigmoid()
              (2): QuantStub()
            )
          )
          (_mul): FloatFunctional(
            (activation_post_process): Identity()
          )
        )
        (_reduce): Sequential(
          (0): ConvBn2d(
            (0): Conv2d(120, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): Identity()
        )
        (_skip_add): FloatFunctional(
          (activation_post_process): Identity()
        )
      )
    )
  )
  (_decoder): QuantizableLRASPP(
    (_aspp_conv1): Sequential(
      (0): ConvBnReLU2d(
        (0): Conv2d(48, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): Identity()
      (2): Identity()
    )
    (_upper2): UpsamplingBilinear2d(size=(49, 78), mode=bilinear)
    (_aspp_conv2): Sequential(
      (0): AvgPool2d(kernel_size=11, stride=(4, 4), padding=0)
      (1): Conv2d(48, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): Sequential(
        (0): DeQuantStub()
        (1): Sigmoid()
        (2): QuantStub()
      )
    )
    (_aspp_conv12): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (_upper12): UpsamplingBilinear2d(size=(98, 155), mode=bilinear)
    (_aspp_conv3): Conv2d(24, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (_upper123): UpsamplingBilinear2d(size=(780, 1240), mode=bilinear)
    (_mul): FloatFunctional(
      (activation_post_process): Identity()
    )
    (_add): FloatFunctional(
      (activation_post_process): Identity()
    )
  )
  (_dequant): DeQuantStub()
)

After the model I also apply an nn.Sigmoid activation for binary per-pixel classification.

If you can offer any advice, please don’t hesitate to reply.

Thanks in advance!

@smivv, could you share the code you used to enable QAT on the model?

  1. We currently do support quantization of Sigmoid, Hardsigmoid and ReLU6. cc @jerryzh168 to confirm.
  2. QAT of UpsamplingBilinear2d isn’t supported, so you will have to wrap it in a DeQuant/Quant block.
  3. Maybe looking at the code will help; we also have a tutorial at https://github.com/pytorch/vision/blob/master/references/classification/train_quantization.py for reference.
  4. QAT training takes longer due to the insertion of observer and fake-quant modules into the model.
  5. It is possible to do QAT on GPU, but you will need to move the model to CPU before running convert (a rough sketch of the overall flow follows this list).
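
Roughly, points 4 and 5 translate into the following flow (a sketch only: the model, data loader, criterion and optimizer names are placeholders, and fuse_model() stands for whatever fusion helper your model provides):

import torch
import torch.quantization as tq

model.fuse_model()                                    # fuse Conv+BN(+ReLU) before QAT
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
model = tq.prepare_qat(model, inplace=False)          # inserts observers + fake-quant

model.to("cuda").train()
for images, targets in train_loader:                  # QAT fine-tuning can run on GPU
    loss = criterion(model(images.cuda()), targets.cuda())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.to("cpu").eval()                                # convert only works on CPU
quantized_model = tq.convert(model, inplace=False)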

Hello @supriyar,

Thanks for your comment!

  1. Are you sure about Sigmoid and Hardsigmoid? When I look at the model’s print after prepare_qat, the only activations without a FakeQuantize attached are Sigmoid and Hardsigmoid, while Hardswish has one (see the snippet after this list for how I check this):
  (2): Hardswish(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
........................................................
  (2): Sequential(
    (0): DeQuantStub()
    (1): Sigmoid()
    (2): QuantStub(
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
  )
  2. Just did it and it didn’t help.
  3. I’ve already seen that, thanks, but there was nothing new in it for me in terms of code :sweat:
  4. Does it take longer in terms of epochs? The wall-clock time per epoch is certainly longer.
  5. Ok, that’s how I do it.
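
To make point 1 concrete, this is the small check I run after prepare_qat (just a helper loop over the prepared model; model here is my float model):

import torch
import torch.nn as nn

prepared = torch.quantization.prepare_qat(model, inplace=False)
for name, module in prepared.named_modules():
    if isinstance(module, (nn.Sigmoid, nn.Hardsigmoid, nn.Hardswish)):
        # observed modules get an activation_post_process child after prepare_qat
        has_fq = hasattr(module, "activation_post_process")
        print(f"{name}: {type(module).__name__}, fake_quant attached: {has_fq}")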

I am using the Catalyst framework, so what I do is prepare the model for QAT once the stage starts. Here is my code:

import torch
from catalyst.core import Callback, CallbackOrder, IRunner


class QuantizationAwareTrainingCallback(Callback):
    def __init__(
        self,
        backend: str = "fbgemm",
    ):
        super().__init__(order=CallbackOrder.Internal + 1)

        assert backend in ["fbgemm", "qnnpack"], "Unknown backend type"

        self.backend = backend

    def on_stage_start(self, runner: "IRunner") -> None:
        # the model is already fused (ConvBn2d / ConvBnReLU2d / LinearReLU) at this point
        runner.model.train()

        # pick the quantized engine and attach the default QAT qconfig
        torch.backends.quantized.engine = self.backend
        runner.model.qconfig = torch.quantization.get_default_qat_qconfig(self.backend)

        # insert observers and fake-quant modules
        runner.model = torch.quantization.prepare_qat(runner.model, inplace=False)

        runner.model.apply(torch.quantization.enable_observer)
        runner.model.apply(torch.quantization.enable_fake_quant)
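
For completeness, this is the counterpart I would pair with it at the end of the stage, following the advice above about converting on CPU (a hypothetical addition to the same callback, not something I have run yet):

    def on_stage_end(self, runner: "IRunner") -> None:
        # convert must run on CPU, so move the model off the GPU first
        runner.model.to("cpu")
        runner.model.eval()
        runner.model = torch.quantization.convert(runner.model, inplace=False)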

Finally I made it working.

The real reason it wasn’t working is that I had PyTorch 1.7 installed, which probably does not have all the quantized operations implemented, or maybe has some bugs.

Now that I have installed PyTorch from source (master branch), the model is training :slight_smile:

Thanks!
