hi
I want to quantize my model with variable precision, but PyTorch only supports 8-bit quantization, so I created a custom fake-quantize module. I successfully quantized a conv-relu layer using my fake-quantize module, but accuracy drops when I quantize a conv-bn-relu layer.
In PyTorch, when I want to quantize (QAT) a conv-bn-relu layer, first I need to fuse the modules into a ConvReLU2d
layer, then I can train this layer. Finally, I can use .int_repr()
to get the weights and .q_scale
to get the scale factor. I followed this flow to design my fake-quantize module.
First, I fuse the conv weight/bias with the BN mean/var/beta/gamma using:
new_weight = conv_weight * bn_beta/(sqrt(bn_var + bn_eps))
new_bias = (conv_bias - bn_mean)/(sqrt(bn_var + bn_eps)) * beta + gamma
Then the new weight and bias are assigned to another conv layer and I train that layer, but accuracy drops after training.
So I want to know how the ConvReLU2d
layer actually works in QAT — how are weight_scale and weight_zero_point calculated?
This is my code (fusing the conv weight/bias with the BN mean/var/beta/gamma):
# Fold each BatchNorm's statistics into the preceding conv of the fused model:
#   W' = W * g / sqrt(var + eps)
#   b' = (b - mean) * g / sqrt(var + eps) + s
# where g is the BN scale (stored in .weight) and s the BN shift (stored in .bias).
for child_name, fused_child in fused_model.named_children():
    if not hasattr(fused_child, 'conv'):
        continue
    print(child_name)

    # BatchNorm statistics taken from the un-fused reference model.
    bn = model.model_list()[child_name][1]
    bn_std = torch.sqrt(bn.running_var + bn.eps)

    # Matching conv parameters; substitute zeros when the conv has no bias.
    conv = model.model_list()[child_name][0]
    conv_weight = conv.weight
    if conv.bias is not None:
        conv_bias = conv.bias
    else:
        conv_bias = bn.running_mean.new_zeros(bn.running_mean.shape)

    # Per-output-channel scale, broadcast over (out_ch, in_ch, kH, kW).
    channel_scale = (bn.weight / bn_std).reshape([conv.out_channels, 1, 1, 1])
    fused_weight = conv_weight * channel_scale
    fused_bias = (conv_bias - bn.running_mean) / bn_std * bn.weight + bn.bias

    fused_child.conv.weight = torch.nn.Parameter(fused_weight)
    fused_child.conv.bias = torch.nn.Parameter(fused_bias)
Fake-quantize function:
def quant_dequant(self, x, q_min, q_max, data_type, name):
    """Fake-quantize `x` to the integer range [q_min, q_max] and dequantize back.

    data_type 'act' uses a per-tensor affine range tracked with an EMA
    (momentum 0.99) over observed min/max; anything else is treated as a
    weight tensor and uses a per-output-channel symmetric range.
    Observed ranges and scales are cached on `self` under `name`.
    Returns a float32 tensor on the same device as `x`.
    """
    # NOTE(review): everything runs under no_grad(), so no gradient flows
    # through this fake-quant op during QAT (there is no straight-through
    # estimator as in torch's FakeQuantize). This is a likely cause of the
    # accuracy drop — confirm against torch.fake_quantize_per_channel_affine.
    with torch.no_grad():
        if data_type == 'act':
            batch_min = torch.min(x)
            batch_max = torch.max(x)
            # EMA-tracked per-tensor range (first batch seeds it directly).
            if self.act_min[name] is None:
                x_min = batch_min
            else:
                x_min = 0.99 * self.act_min[name] + 0.01 * batch_min
            self.act_min[name] = x_min
            if self.act_max[name] is None:
                x_max = batch_max
            else:
                x_max = 0.99 * self.act_max[name] + 0.01 * batch_max
            self.act_max[name] = x_max
            scale = (x_max - x_min) / (q_max - q_min)
            self.act_scale[name] = scale.cpu()
        else:
            # Per-output-channel range: reduce over (in_ch, kH, kW) at once
            # instead of three chained torch.min/torch.max calls.
            min_data = torch.amin(x, dim=(1, 2, 3))
            max_data = torch.amax(x, dim=(1, 2, 3))
            if self.weight_min[name] is None:
                x_min = min_data
            else:
                x_min = 0.99 * self.weight_min[name] + 0.01 * min_data
            self.weight_min[name] = x_min
            if self.weight_max[name] is None:
                x_max = max_data
            else:
                x_max = 0.99 * self.weight_max[name] + 0.01 * max_data
            self.weight_max[name] = x_max
            # Symmetric scale from the largest absolute per-channel value,
            # shaped (C, 1, 1, 1) so it broadcasts over the weight tensor.
            scale = 2 * torch.maximum(x_max, torch.abs(x_min)) / (q_max - q_min)
            scale = scale.reshape(-1, 1, 1, 1)
            self.weight_scale[name] = scale.reshape(-1).cpu()
        # Symmetric zero-point for both paths (activations are assumed
        # non-negative post-ReLU — verify against the caller).
        zero = 0
        # Guard against a degenerate all-constant tensor (scale == 0).
        scale = torch.clamp(scale, min=1e-12).to(x.device)
        # Quantize: scale, round, clamp to the integer grid.
        q_x = torch.round(x / scale + zero)
        q_x = torch.clamp(q_x, min=q_min, max=q_max)
        # Dequantize back to float32. `.to(...)` replaces the deprecated
        # copy-construction torch.tensor(tensor, dtype=...), and dropping the
        # hard-coded .cuda() keeps the result on the input's device.
        q_x = q_x.to(torch.int32)
        x = ((q_x - zero) * scale).to(torch.float32)
        return x
thank you