Question about prepare_fx

In practice, if I have an nn.Conv2d after an nn.Conv2d, prepare_fx inserts an observer between the two modules.

But if I have an nn.BatchNorm2d after an nn.Conv2d (with NO FUSING), prepare_fx does not insert an observer, which means convert_fx won't convert the reference conv2d module into a quantized conv2d module.

hi, can you share a simple repro here exhibiting the problem? how are you disabling the fusion of Conv+BN in the prepare step?

Sure, basically I just comment out the fusing part in prepare_fx, but you can reproduce it like below:

import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

class MyModule(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
    self.bn1 = nn.BatchNorm2d(16)
    self.relu = nn.ReLU()
    self.check = nn.Conv2d(16, 32, 3, 1, 1)

  def forward(self, x):
    x = self.check(self.relu(self.bn1(self.conv1(x))))
    return x

and use:

model = MyModule().eval()
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("x86"))
prepared = prepare_fx(model, qconfig_mapping, example_inputs=(torch.randn(1, 3, 1, 1),))

The conv1-bn1-relu sequence would be fused into a single conv module, and you would see one observer inserted between this fused module and the check module.
But if we replace self.check with nn.BatchNorm2d(16), the observer is not inserted.

When I run this snippet locally, I see the observers inserted after conv1 and before self.check.
My test code

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.check = nn.BatchNorm2d(16)  # nn.Conv2d(16, 32, 3, 1, 1)

    def forward(self, x):
        x = self.check(self.relu(self.bn1(self.conv1(x))))
        return x

model = M().eval()
mp = prepare_fx(model, get_default_qconfig_mapping(), example_inputs=torch.randn(1, 3, 1, 1))
print(mp)
mc = convert_fx(mp)
print(mc)

Here is the output from prepare

GraphModule(
  (activation_post_process_0): HistogramObserver(min_val=inf, max_val=-inf)
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (activation_post_process_1): HistogramObserver(min_val=inf, max_val=-inf)
  (check): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation_post_process_2): HistogramObserver(min_val=inf, max_val=-inf)
)



def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x);  x = None
    conv1 = self.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv1);  conv1 = None
    check = self.check(activation_post_process_1);  activation_post_process_1 = None
    activation_post_process_2 = self.activation_post_process_2(check);  check = None
    return activation_post_process_2

output after convert

GraphModule(
  (conv1): QuantizedConvReLU2d(3, 16, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (check): QuantizedBatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)



def forward(self, x):
    conv1_input_scale_0 = self.conv1_input_scale_0
    conv1_input_zero_point_0 = self.conv1_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8);  x = conv1_input_scale_0 = conv1_input_zero_point_0 = None
    conv1 = self.conv1(quantize_per_tensor);  quantize_per_tensor = None
    check = self.check(conv1);  conv1 = None
    dequantize_2 = check.dequantize();  check = None
    return dequantize_2

My bad, I shouldn’t have added self.relu, so here’s what I got:

class M(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3,16,3,1,1)
    self.bn1 = nn.BatchNorm2d(16)
    self.check = nn.BatchNorm2d(16) #nn.Conv2d(16, 32, 3, 1, 1)

  def forward(self, x):
    x = self.check(self.bn1(self.conv1(x)))
    return x

and after prepare:

GraphModule(
  (activation_post_process_0): HistogramObserver(min_val=inf, max_val=-inf)
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (check): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation_post_process_1): HistogramObserver(min_val=inf, max_val=-inf)
)

after convert:

GraphModule(
  (conv1): QuantizedConv2d(Reference)(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (check): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

I think self.conv1 was already converted into QuantizedConv2d.

conv1 is still a Reference convolution, which does not take a quint8 tensor input, and check is still the floating-point version of BatchNorm2d.

I’d like to end up with a QuantizedConv2d and a QuantizedBatchNorm2d module, like in @supriyar’s comment above.
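
For reference, a quick way to check what convert_fx actually produced (a minimal sketch, assuming mc is the converted module from the snippet above) is to print the module types and the graph:

# mc is assumed to be the GraphModule returned by convert_fx above
print(type(mc.conv1))   # reference conv class vs. a fully quantized conv class
print(type(mc.check))   # still nn.BatchNorm2d here, i.e. not quantized
print(mc.graph)         # shows where quantize_per_tensor / dequantize nodes sit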

this looks like a bug, can you file an issue on Issues · pytorch/pytorch · GitHub with the quantization label? Thanks!

Sure, I’ll open one. By the way, is there any way we can work around this potential bug?

This isn’t working because you’re using default settings for something that is far from default. In general, Conv + multiple BatchNorms behaves the same as Conv + a single BatchNorm, which is mathematically equivalent to just a Conv op with different weights, and that’s what we handle by default by fusing conv-bn → conv. The issue is that the flow tries to do a fused quantization of conv-bn, but the code assumes the pair has already been fused so that only a conv remains; with multiple batchnorms, only one gets fused and the remaining one causes a conflict.
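
For intuition, here is a minimal sketch (my own example, not code from the quantization stack) of the folding that the default conv-bn fusion relies on: an eval-mode BatchNorm2d can be absorbed into the preceding Conv2d’s weight and bias.

import torch
import torch.nn as nn

# Fold an eval-mode BatchNorm2d into the preceding Conv2d:
#   W' = W * gamma / sqrt(running_var + eps)
#   b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
conv = nn.Conv2d(3, 16, 3, 1, 1)
bn = nn.BatchNorm2d(16).eval()

with torch.no_grad():
    # give the BN non-trivial statistics so the check is meaningful
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2)
    bn.weight.uniform_(0.5, 2)
    bn.bias.uniform_(-1, 1)

    fused = nn.Conv2d(3, 16, 3, 1, 1)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    fused.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.randn(1, 3, 8, 8)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True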

If you really need to get around this, the easiest options are to either insert a torch.nn.Identity between the two batchnorms in your model, or to make a new backend_config that handles the above situation, which would be equivalent to just commenting out these lines from the default backend config.
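
For example, a rough sketch of the first option, assuming the conv1/bn1/check model from earlier in the thread (the sep attribute name is made up for illustration; whether this avoids the conflict may depend on your PyTorch version):

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.sep = nn.Identity()   # separates the two batchnorms in the traced graph
        self.check = nn.BatchNorm2d(16)

    def forward(self, x):
        return self.check(self.sep(self.bn1(self.conv1(x))))

model = M().eval()
mp = prepare_fx(model, get_default_qconfig_mapping(),
                example_inputs=(torch.randn(1, 3, 8, 8),))
mp(torch.randn(1, 3, 8, 8))   # calibrate
mc = convert_fx(mp)
print(mc)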