Prepare_qat on module removes hooks

Hello,

I’ve just started to dive into quantization tools that were introduced in version 1.3.
For now I am trying to train a network built by existing model-generation code. That code has certain subtleties; one of them is _forward_pre_hooks in several submodules (see the code below).

Here is the problem: after prepare_qat with the default config, the submodules are replaced with fake-quantized versions and the hooks disappear. Is it possible to prevent the hooks from disappearing during prepare_qat (and during submodule.fuse() as well)?

Here is the relevant part of the code:

...
print('pre qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl)
print('pre qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl._forward_pre_hooks.values())
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)
print('post qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl)
print('post qat: ', model.backbone.bottom_up.blocks[2][3].conv_pwl._forward_pre_hooks.values())
...

And the output is:

pre qat:  Conv2d(120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)                                                                                                                                    
pre qat:  odict_values([functools.partial(<bound method FeatureHooks._collect_output_hook of <timm.models.feature_hooks.FeatureHooks object at 0x7fd3b5bcf5d0>>, 'blocks.2.3.conv_pwl')])
post qat:  Conv2d(
  120, 40, kernel_size=(1, 1), stride=(1, 1), bias=False
  (activation_post_process): FakeQuantize(
    fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
  (weight_fake_quant): FakeQuantize(
    fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.]), zero_point=tensor([0])
    (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
  )
)
post qat:  odict_values([])

As far as I can tell, to prevent the hooks from disappearing, the relevant code needs to go into the module-swapping step of prepare_qat.
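One possible user-side workaround is to save the hooks before prepare_qat and re-register them on the swapped modules afterwards. A rough sketch (save_pre_hooks / restore_pre_hooks are names I made up; it relies on module names staying the same after the swap):

import torch

def save_pre_hooks(model):
    # remember user pre-forward hooks per module name before the swap
    saved = {}
    for name, module in model.named_modules():
        if module._forward_pre_hooks:
            saved[name] = list(module._forward_pre_hooks.values())
    return saved

def restore_pre_hooks(model, saved):
    # re-register the saved hooks on the (possibly swapped) modules with the same names
    modules = dict(model.named_modules())
    for name, hooks in saved.items():
        for hook in hooks:
            modules[name].register_forward_pre_hook(hook)

saved = save_pre_hooks(model)
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)
restore_pre_hooks(model, saved)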

But during fusion it is not that obvious, because we map up to three modules onto one and each of them could have hooks. I think it is possible to work around this with torch.quantization.fuse_modules(..., fuser_func=<func>, ...).
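Something along these lines; this is just a sketch against the 1.3/1.4 eager-mode API, where fuse_known_modules(mod_list) takes only the module list and returns the fused module first followed by identity placeholders (the model and the 'conv' / 'bn' / 'relu' names are placeholders of mine):

import torch
from torch.quantization.fuse_modules import fuse_known_modules

def fuse_known_modules_keep_hooks(mod_list):
    # remember the pre-forward hooks of the first (base) module in the sequence
    pre_hooks = list(mod_list[0]._forward_pre_hooks.values())
    fused = fuse_known_modules(mod_list)
    # re-attach them to the fused module, which replaces the base module
    for hook in pre_hooks:
        fused[0].register_forward_pre_hook(hook)
    return fused

model = torch.quantization.fuse_modules(
    model, [['conv', 'bn', 'relu']],
    fuser_func=fuse_known_modules_keep_hooks)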

Yeah, feel free to submit a PR for preserving the pre-hooks in swapping.

For fusion I think we probably need to error out if you have a pre-hook on an intermediate module like BatchNorm, because when we fuse BatchNorm into Conv, the BatchNorm module is gone.

I’ll try!

So, Jerry, could you please explain why that was done?

After prepare() is called, convert() will remove all hooks.
I understand there is a reason to create such a hook and then remove it, but should we preserve all pre-forward / post-forward hooks except that one?

Yeah, I think we should preserve it, but we need to consider each case carefully, since this interferes with the quantization. That is, when do we run the pre-forward hook and the post-forward hook: before or after observation/quantization?

Hello again, @jerryzh168,

Have a look at this example of a pre-forward hook in use. Here is an EfficientNet implementation that assumes it can be integrated as a backbone for an FPN (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py#L471).
Here we can see that during forward the blocks are called sequentially, and then the inputs of specific layers are collected via pre-forward hooks.
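Roughly, the pattern looks like this (a simplified sketch of the idea, not the actual timm FeatureHooks code):

from functools import partial

class FeatureCollector:
    # collects the inputs of selected submodules via pre-forward hooks
    def __init__(self, model, module_names):
        self.features = {}
        modules = dict(model.named_modules())
        for name in module_names:
            modules[name].register_forward_pre_hook(partial(self._collect, name))

    def _collect(self, name, module, inputs):
        # pre-forward hooks receive (module, inputs); inputs is a tuple
        self.features[name] = inputs[0]

For example, FeatureCollector(backbone, ['blocks.2.3.conv_pwl']) fills collector.features with the captured tensors after one forward pass; those are the features the FPN consumes.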

Assume that we have block1, which outputs some data; block2, which has a pre-forward hook and obtains the data directly from block1; and block3, which needs the same data but obtains it via block2’s hook.

In this particular example (the EfficientNet implementation), the pre-forward hook on block2 should be called after observation during preparation, and therefore after quantization later (because that is the data we collected statistics for). If block2 happens to be the first module in the chain that works with quantized data, it is very likely that block3 works with quantized data too; in any case, we can place a dequant before block3.
As for post-forward hooks, they do not modify the module’s input, so we can run them after observation and quantization as well.
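To make the dequant placement concrete, here is a toy sketch (my own example, not the EfficientNet code; block3 can be kept in float by setting block3.qconfig = None before prepare):

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.block1 = nn.Conv2d(3, 8, 3, padding=1)
        self.block2 = nn.Conv2d(8, 8, 3, padding=1)
        self.dequant = DeQuantStub()   # dequantize the captured feature before block3
        self.block3 = nn.Conv2d(8, 8, 3, padding=1)
        self.captured = None
        # pre-forward hook on block2 captures its input, i.e. block1's output
        self.block2.register_forward_pre_hook(self._capture)

    def _capture(self, module, inputs):
        self.captured = inputs[0]

    def forward(self, x):
        x = self.quant(x)
        x = self.block1(x)
        x = self.block2(x)
        y = self.block3(self.dequant(self.captured))
        return x, y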

Please leave your thoughts when you have time.

I think this requires that the hooks do some meaningful computation on both quantized and unquantized data. But as long as we define the semantics clearly it should be OK. So after quantization, the pre-forward hooks will work with the quantized input from the previous layer, and the forward hooks will work with the quantized output from the current layer, right?

Actually, we use hooks to implement observation and fake quantization as well, so please make sure that still works.

Yes, I think so:
pre-forward hooks work with the quantized data from the previous layer, and forward hooks work with the quantized output from the current layer.

You are right, we should check whether the hook we are preserving is the observer hook or a user hook.
I have handled that case in my PR: without it the provided test set fails, and with it everything passes.
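Roughly, the idea is to distinguish the hook that prepare() itself registers (the one that runs the observer / fake quant on the module output) from user hooks, something like this (a sketch; _observer_forward_hook is a private helper in torch/quantization/quantize.py, so this detail may change between releases):

from torch.quantization.quantize import _observer_forward_hook

def user_forward_hooks(module):
    # everything except the internally registered observer hook is a user hook
    return [hook for hook in module._forward_hooks.values()
            if hook is not _observer_forward_hook]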

I think I should extend the test set to cover the new functionality. Would you mind guiding me on that?

Also, to anticipate possible questions:

I’ve introduced changes to fuse_modules.py

I propose preserving pre- and post-forward hooks on fused modules where possible. It is hard to define how to preserve hooks for the second and third module in the sequence (because the input data changes and three modules are collapsed into a single atomic one), but we can easily preserve the pre-forward hook of the base module.

What other cases could we handle?

Can you post the PR?

Here it is