Prepare_qat on module removes hooks

Hello,

I’ve just started to dive into quantization tools that were introduced in version 1.3.
For now I am trying to train a network using existing model-generation code. The code has certain subtleties; one of them is the use of _forward_pre_hooks in several submodules (see the code below).

Here is the problem: after prepare_qat with the default config, the submodules are replaced with fake-quantized versions and the hooks disappear. Is it possible to prevent the hooks from disappearing during prepare_qat (and during submodule.fuse() as well)?

Here is the relevant code:

...
print('pre qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl)
print('pre qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl._forward_pre_hooks.values())
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)
print('post qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl)
print('post qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl._forward_pre_hooks.values())
...

And the output is:

pre qat:  Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)                                                                                                                                    
pre qat:  odict_values([functools.partial(<bound method FeatureHooks._collect_output_hook of <timm.models.feature_hooks.FeatureHooks object at 0x7fd3b5bcf5d0>>, 'blocks.2.3.conv_pwl')])
post qat:  Conv2d(
  120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False
  (activation_post_process): FakeQuantize(
    fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (weight_fake_quant): FakeQuantize(
    fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
)
post qat:  odict_values([])

As far as I can tell, to prevent the hooks from disappearing, the relevant code would need to go somewhere in prepare_qat, where the modules are swapped.
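
In the meantime, a user-side workaround for the prepare_qat case seems possible: record the registered pre-hooks before prepare_qat and re-attach them afterwards. A rough sketch (the helper names are mine, and it assumes the swapped modules keep the same names; model is the same model as above):

import torch

def snapshot_pre_hooks(model):
    # Record user-registered forward pre-hooks per module name.
    return {name: list(m._forward_pre_hooks.values())
            for name, m in model.named_modules()
            if len(m._forward_pre_hooks) > 0}

def restore_pre_hooks(model, snapshot):
    # Re-register the recorded hooks on the (possibly swapped) modules.
    modules = dict(model.named_modules())
    for name, hooks in snapshot.items():
        for hook in hooks:
            modules[name].register_forward_pre_hook(hook)

saved = snapshot_pre_hooks(model)
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)
restore_pre_hooks(model, saved)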

But during fusion it is not that obvious, because we map up to three modules to one, and each of them could have hooks. I think it is possible to work around this with torch.quantization.fuse_modules(..., fuser_func=<func>, ...), for example along the lines of the sketch below.
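
A thin wrapper around the default fuser could carry the hooks over to the fused module. A rough sketch, assuming the default fuser is torch.quantization.fuse_modules.fuse_known_modules and that fuser_func receives the list of modules to fuse and returns a list of the same length:

import torch
from torch.quantization.fuse_modules import fuse_known_modules

def hook_preserving_fuser(mod_list):
    # Collect pre-hooks from every module in the group before fusion.
    pre_hooks = []
    for m in mod_list:
        pre_hooks.extend(m._forward_pre_hooks.values())
    # Let the default fuser build the fused module (index 0); the remaining
    # slots become Identity placeholders.
    fused = fuse_known_modules(mod_list)
    # Re-register the collected hooks on the fused module. Note: a hook that
    # used to fire on an intermediate module (e.g. BatchNorm) now fires on
    # the fused module's input, which may not be what the hook expects.
    for hook in pre_hooks:
        fused[0].register_forward_pre_hook(hook)
    return fused

# Usage (module names are placeholders):
# torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']],
#                                 inplace=True, fuser_func=hook_preserving_fuser)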

Yeah feel free to submit a PR for preserving the pre hooks in swapping.

For fusion, I think we probably need to error out if you have a pre-hook on an intermediate module like BatchNorm, because when we fuse BatchNorm into Conv, the BatchNorm module is gone.
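
Something like this check might capture that (a rough sketch; the helper is hypothetical and not part of torch.quantization):

def check_no_intermediate_pre_hooks(mod_list):
    # The first module (e.g. Conv) survives fusion; the rest (e.g. BatchNorm,
    # ReLU) are folded away, so any pre-hooks on them would silently vanish.
    for m in mod_list[1:]:
        if len(m._forward_pre_hooks) > 0:
            raise RuntimeError(
                'Cannot fuse: {} has forward pre-hooks but will be removed '
                'by fusion'.format(type(m).__name__))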

I’ll try!

So Jerry, could you please explain why that was done?

After prepare() is called, convert() will remove all hooks.
I assume there is a reason to create such a hook and then remove it, but should we preserve all pre-forward / post-forward hooks except that internal one?

Yeah, I think we should preserve them, but we need to consider each case carefully, since this interferes with quantization. That is, when do we run the pre-forward and post-forward hooks: before or after observation/quantization?
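
To illustrate the ordering question with plain PyTorch (this is generic hook behavior, not the quantization internals): forward hooks run in registration order, so whether a user hook sees the output before or after an observer-style hook depends on which one was registered first.

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)

def user_hook(module, inp, out):
    print('user forward hook runs')

def observer_style_hook(module, inp, out):
    print('observer-style forward hook runs')

conv.register_forward_hook(user_hook)            # registered first -> runs first
conv.register_forward_hook(observer_style_hook)  # registered second -> runs second

conv(torch.randn(1, 3, 16, 16))
# prints:
# user forward hook runs
# observer-style forward hook runs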