Quantized Batch Norm operation

I have a quantized model containing a BatchNorm layer, and I would like to know what operation it performs to transform the input into the output.

The code I am using is:

import numpy as np
import torch
import torch.nn as nn
import torch.quantization
from custom_convolve import convolve_torch, convolve_numpy
# Seed torch's RNG so the torch.rand(...) input drawn later is identical on every run.
torch.manual_seed(123)
# Print 30 decimals for both torch tensors and numpy arrays so that
# quantized/dequantized values can be compared digit-for-digit.
torch.set_printoptions(precision=30)
np.set_printoptions(precision=30)

class M_quant_fullweight(nn.Module):
    """A single BatchNorm2d layer wrapped between quantization stubs.

    The attribute names (``quant``, ``conv1``, ``dequant``), their
    registration order, and the ``nn.Sequential`` wrapper are significant:
    they determine the state-dict keys (e.g. ``conv1.0.running_mean``)
    loaded from the checkpoint and the order of ``named_modules()``.
    """

    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        # (it becomes an actual quantize op once prepare/convert have run).
        self.quant = torch.quantization.QuantStub()
        # Despite its name, this holds only a 3-channel BatchNorm2d.
        self.conv1 = nn.Sequential(nn.BatchNorm2d(3))
        # Converts the quantized result back to floating point.
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize *x*, apply the BatchNorm, then dequantize the result."""
        quantized = self.quant(x)
        normalized = self.conv1(quantized)
        return self.dequant(normalized)

# Random input of shape (1, 3, 3, 3); reproducible thanks to manual_seed(123).
input_fp32 = torch.rand(1, 3, 3, 3)

model_quant = M_quant_fullweight()
model_quant.eval()

# Select the quantized backend *before* preparing/converting, so the qconfig
# chosen below and the engine that later runs the quantized kernels agree.
torch.backends.quantized.engine = 'fbgemm'
model_quant.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model_quant, inplace=True)
# NOTE(review): no calibration forward pass runs between prepare and convert,
# so the observers have seen no data. The quantization parameters produced
# here are presumably placeholders that load_state_dict below overwrites.
torch.quantization.convert(model_quant, inplace=True)

print(model_quant)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load('BatchNorm_layer.pt', map_location=torch.device('cpu'))
model_quant.load_state_dict(checkpoint['state_dict'])

# Records appended by custom_hook, in the order the hooked modules run.
activations = []


def custom_hook(module, input, output):
    """Forward hook that records one call of a module.

    Appends a dict with keys ``'module'``, ``'input'`` (the positional-input
    tuple the hook receives) and ``'output'`` to ``activations``.
    """
    activations.append(dict(module=module, input=input, output=output))

# Attach custom_hook to every leaf module (one without children). Container
# modules such as the root model and the nn.Sequential are skipped. Hooks
# are registered in named_modules() order, which is also the order the
# records appear in `activations` during a forward pass.
for name, module in model_quant.named_modules():
    if next(module.children(), None) is not None:
        continue
    print(name, module)
    module.register_forward_hook(custom_hook)

output_quant = model_quant(input_fp32)

bn = model_quant.conv1[0]

print("Running Mean ", bn.running_mean)
print("Running Variance ", bn.running_var)
# The quantized BatchNorm also applies the learned affine weight and bias,
# so both are needed to reproduce its output by hand.
print("Batch Norm Weight ", bn.weight)
print("Batch Norm Bias ", bn.bias)
print("Batch Norm Scale ", bn.scale)
print("Batch Norm Zero Point", bn.zero_point)
print("Batch Norm Eps ", bn.eps)

# Find the BatchNorm's hook record by identity rather than assuming it is
# the second hooked leaf (index 1) in `activations`.
bn_record = next(rec for rec in activations if rec['module'] is bn)

print("Input to batch norm ", bn_record['input'][0][0][0][0][0])
print("Output of batch norm ", bn_record['output'][0][0][0][0])

# Manual reference for the same element (index [0, 0, 0, 0], channel 0):
#   y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
# followed by per-tensor affine quantization: q = round(y / scale) + zero_point.
# NOTE(review): the fused quantized kernel may evaluate this in a different
# floating-point order, so its integer code can differ by one step.
with torch.no_grad():
    channel = 0
    x_in = bn_record['input'][0].dequantize()[0, channel, 0, 0]
    inv_std = 1.0 / torch.sqrt(bn.running_var[channel] + bn.eps)
    y_float = (x_in - bn.running_mean[channel]) * inv_std * bn.weight[channel] + bn.bias[channel]
    out_scale = float(bn.scale)
    out_zero_point = int(bn.zero_point)
    # Assumes a quint8 output (as in the printed output tensor): codes 0..255.
    q_code = torch.clamp(torch.round(y_float / out_scale) + out_zero_point, 0, 255)
    print("Manual float result ", y_float)
    print("Manual quantized code ", q_code)
    print("Manual dequantized result ", (q_code - out_zero_point) * out_scale)

The checkpoint file (`BatchNorm_layer.pt`) can be downloaded here.

One of the inputs to the QuantizedBatchNorm is:

tensor(0.299066513776779174804687500000, size=(), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.007870171219110489,
       zero_point=0)

and the corresponding output is:

tensor(-110.886795043945312500000000000000, size=(), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=5.280323505401611,
       zero_point=73)

with the following BatchNorm parameters:

Running Mean  tensor([0.565347611904144287109375000000, 0.431774944067001342773437500000,
        0.367271929979324340820312500000])
Running Variance  tensor([0.079505279660224914550781250000, 0.062080144882202148437500000000,
        0.053845494985580444335937500000])
Batch Norm Scale  tensor([5.280323505401611328125000000000])
Batch Norm Zero Point tensor([73])

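I understand that normalization subtracts the first channel's running mean and divides by its running variance. I also understand that quantization divides by the scale, rounds, and adds the zero point. (Strictly, BatchNorm divides by the standard deviation $\sqrt{\sigma^2 + \epsilon}$ rather than by the variance itself, and then applies the learned weight $\gamma$ and bias $\beta$: $y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma + \beta$, after which $q = \mathrm{round}(y / s) + z$.) However, I cannot reproduce the output manually. Any help would be appreciated.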

Sorry for the late reply — are you asking what computation BatchNorm performs?