Custom fake-quantize module with variable precision

Hi,
I want to quantize my model at variable precision, but PyTorch only supports 8-bit quantization, so I created a custom fake-quantize module. I can successfully quantize a conv-relu layer with it, but the accuracy drops when I quantize a conv-bn-relu layer.

In PyTorch, when I want to quantize a conv-bn-relu layer with QAT, I first have to fuse the modules into a ConvReLU2d layer, and then I can train that layer. At the end, I can call .int_repr() to get the integer weights and .q_scale() to get the scale factor. I followed this flow when designing my fake-quantize module.
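
For reference, this is the standard eager-mode QAT flow I am comparing against, as far as I understand it (a minimal sketch: fuse_modules_qat needs a recent PyTorch, and the qnnpack qconfig is just an example choice that keeps weights per-tensor so q_scale() works):

    import torch
    import torch.nn as nn
    import torch.ao.quantization as tq

    class ConvBNReLU(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant   = tq.QuantStub()
            self.conv    = nn.Conv2d(3, 16, 3, padding=1)
            self.bn      = nn.BatchNorm2d(16)
            self.relu    = nn.ReLU()
            self.dequant = tq.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.relu(self.bn(self.conv(self.quant(x)))))

    m = ConvBNReLU().train()
    m = tq.fuse_modules_qat(m, [['conv', 'bn', 'relu']])  # -> ConvBnReLU2d at m.conv
    m.qconfig = tq.get_default_qat_qconfig('qnnpack')     # per-tensor weight scheme
    m = tq.prepare_qat(m)
    # ... usual QAT training loop on m ...
    m.eval()
    q = tq.convert(m)
    w = q.conv.weight()                  # quantized weight tensor
    print(w.int_repr(), w.q_scale())     # integer values and the scale factor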

First, I fuse the conv weight and bias with the BN running mean, running variance, gamma, and beta (here bn_gamma = bn.weight, the BN scale, and bn_beta = bn.bias, the BN shift):

new_weight = conv_weight * bn_gamma / sqrt(bn_var + bn_eps)
new_bias   = (conv_bias - bn_mean) / sqrt(bn_var + bn_eps) * bn_gamma + bn_beta
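
A quick way to sanity-check these formulas (a minimal standalone sketch, assuming eval-mode BN so that the running statistics are what get folded in):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, 3)
    bn   = nn.BatchNorm2d(8).eval()
    with torch.no_grad():
        bn.running_mean.uniform_(-1, 1)   # give BN non-trivial statistics
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)      # gamma (scale)
        bn.bias.uniform_(-0.5, 0.5)       # beta (shift)

    std = torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(3, 8, 3)
    with torch.no_grad():
        fused.weight.copy_(conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1))
        fused.bias.copy_((conv.bias - bn.running_mean) / std * bn.weight + bn.bias)

    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)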

Then I assign the new weight and bias to another conv layer and train that layer, but the accuracy drops after training.

So I want to know how the ConvReLU2d layer actually works in QAT: how are weight_scale and weight_zero_point calculated?
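
From reading the docs, I think the built-in weight observers do something like the following (a sketch only; I am not sure this is exactly what ConvReLU2d uses internally):

    import torch
    import torch.ao.quantization as tq

    # Per-channel symmetric weight observer, similar to what the default
    # fbgemm QAT qconfig uses. For a symmetric qint8 scheme the zero point
    # is fixed at 0, and scale is roughly max(|w_min|, |w_max|) per output
    # channel divided by half the quantized range.
    obs = tq.PerChannelMinMaxObserver(dtype=torch.qint8,
                                      qscheme=torch.per_channel_symmetric)
    w = torch.randn(16, 3, 3, 3)
    obs(w)
    scale, zero_point = obs.calculate_qparams()
    print(scale.shape, zero_point)    # one scale per output channel, zeros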

This is my code (fusing the conv weight and bias with the BN mean, var, gamma, and beta):

for name, param in fused_model.named_children():
    if hasattr(param, 'conv'):
        print(name)

        # BatchNorm parameters of the original (unfused) model
        model_bn = model.model_list()[name][1]

        model_bn_mean     = model_bn.running_mean
        model_bn_var_sqrt = torch.sqrt(model_bn.running_var + model_bn.eps)
        model_bn_gamma    = model_bn.weight   # BN scale
        model_bn_beta     = model_bn.bias     # BN shift

        # Conv parameters of the original model
        model_conv = model.model_list()[name][0]

        model_conv_weight = model_conv.weight
        if model_conv.bias is not None:
            model_conv_bias = model_conv.bias
        else:
            model_conv_bias = model_bn_mean.new_zeros(model_bn_mean.shape)

        # new parameters: fold BN into the conv weight and bias
        w1 = (model_bn_gamma / model_bn_var_sqrt)
        w1 = w1.reshape([model_conv.out_channels, 1, 1, 1])
        w  = model_conv_weight * w1

        b = (model_conv_bias - model_bn_mean) / model_bn_var_sqrt * model_bn_gamma + model_bn_beta

        param.conv.weight = torch.nn.Parameter(w)
        param.conv.bias   = torch.nn.Parameter(b)

Fake quantize:

def quant_dequant(self, x, q_min, q_max, data_type, name):

    with torch.no_grad():
        if data_type == 'act':
            # running (EMA) min/max per layer for activations
            if self.act_min[name] is None:
                x_min = torch.min(x)
            else:
                x_min = 0.99 * self.act_min[name] + 0.01 * torch.min(x)
            self.act_min[name] = x_min

            if self.act_max[name] is None:
                x_max = torch.max(x)
            else:
                x_max = 0.99 * self.act_max[name] + 0.01 * torch.max(x)
            self.act_max[name] = x_max

            # per-tensor affine (asymmetric) scheme for activations
            scale = (x_max - x_min) / (q_max - q_min)
            zero  = torch.clamp(torch.round(q_min - x_min / scale), q_min, q_max)
            self.act_scale[name] = scale.cpu()

        else:
            # per-output-channel min/max for weights
            # (reduce over in-channels and kernel dims)
            min_data = torch.amin(x, dim=(1, 2, 3))
            max_data = torch.amax(x, dim=(1, 2, 3))

            if self.weight_min[name] is None:
                x_min = min_data
            else:
                x_min = 0.99 * self.weight_min[name] + 0.01 * min_data
            self.weight_min[name] = x_min

            if self.weight_max[name] is None:
                x_max = max_data
            else:
                x_max = 0.99 * self.weight_max[name] + 0.01 * max_data
            self.weight_max[name] = x_max

            # per-channel symmetric scheme for weights: zero point stays 0
            scale = 2 * torch.maximum(x_max, torch.abs(x_min)) / (q_max - q_min)
            scale = scale.reshape(-1, 1, 1, 1)
            zero  = 0

            self.weight_scale[name] = scale.reshape(-1).cpu()

        # quantize
        scale = scale.to(x.device)
        q_x = torch.round(x / scale + zero)
        q_x = torch.clamp(q_x, min=q_min, max=q_max)
        # de-quantize (round-trip through the integer representation)
        q_x = q_x.to(torch.int32)
        x = ((q_x - zero) * scale).to(torch.float32)

    return x
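
For context, this is roughly how I call it (a hypothetical usage sketch: the names conv1 and activation are placeholders, and act_min/act_max/weight_min/weight_max are dicts keyed by layer name, initialised to None):

    # hypothetical 4-bit weight / 4-bit activation example
    w_fq = self.quant_dequant(conv.weight, q_min=-8, q_max=7,
                              data_type='weight', name='conv1')
    a_fq = self.quant_dequant(activation, q_min=0, q_max=15,
                              data_type='act', name='conv1')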

Thank you!