When I run this snippet locally, I see the observers inserted after `conv1` and before `self.check`.

My test code

```
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.check = nn.BatchNorm2d(16)  # nn.Conv2d(16, 32, 3, 1, 1)

    def forward(self, x):
        x = self.check(self.relu(self.bn1(self.conv1(x))))
        return x

model = M().eval()
mp = prepare_fx(model, get_default_qconfig_mapping(), example_inputs=torch.randn(1, 3, 1, 1))
print(mp)
mc = convert_fx(mp)
print(mc)
```

Here is the output from `prepare_fx`:

```
GraphModule(
  (activation_post_process_0): HistogramObserver(min_val=inf, max_val=-inf)
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (activation_post_process_1): HistogramObserver(min_val=inf, max_val=-inf)
  (check): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation_post_process_2): HistogramObserver(min_val=inf, max_val=-inf)
)

def forward(self, x):
    activation_post_process_0 = self.activation_post_process_0(x); x = None
    conv1 = self.conv1(activation_post_process_0); activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(conv1); conv1 = None
    check = self.check(activation_post_process_1); activation_post_process_1 = None
    activation_post_process_2 = self.activation_post_process_2(check); check = None
    return activation_post_process_2
```

And the output after `convert_fx`:

```
GraphModule(
  (conv1): QuantizedConvReLU2d(3, 16, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (check): QuantizedBatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

def forward(self, x):
    conv1_input_scale_0 = self.conv1_input_scale_0
    conv1_input_zero_point_0 = self.conv1_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8); x = conv1_input_scale_0 = conv1_input_zero_point_0 = None
    conv1 = self.conv1(quantize_per_tensor); quantize_per_tensor = None
    check = self.check(conv1); conv1 = None
    dequantize_2 = check.dequantize(); check = None
    return dequantize_2
```