Changing Quantized Weights

I have a post-training, statically quantized NN. I want to change a couple of weight values of one of the convolution layers before inference. The weight change should be based on the int8 values, not on the save format (which is torch.qint8 with corresponding scales and zero points). So far I have done the following:

# instantiate the quantized net (not shown here).

# get one of the conv layers
tmp = model_int8.state_dict()['features.0.weight']

# get int repr
tmp_int8 = tmp.int_repr()

# change value (value change is dependent on the int8 value)
tmp_int8[0][0][0][0] = new_value

# TODO: how to convert tmp_int8 to torch.qint8 type?
new_tmp = convert_int8_to_qint8(tmp_int8) # how to do this

# TODO: based on the above step:
model_int8.state_dict()['features.0.weight'] = new_tmp

My question is: how do I convert the int8 tensor to torch.qint8 using the scales and zero points of the original weight tensor (something similar to torch.quantize_per_channel(), but going from int8 to qint8 instead of from float)? Or is there another way to do this?

Thank you.

If this is a per-channel quantized tensor, you can call torch._make_per_channel_quantized_tensor (see native_functions.yaml in pytorch/pytorch on GitHub) to assemble a per-channel quantized tensor from the int_repr, scales, zero_points, etc.
However, this API might be deprecated in future PyTorch releases.
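
In code, that would look roughly like this (a sketch, reusing tmp and tmp_int8 from the first post):

# rebuild a qint8 tensor from the modified int8 values, reusing the
# quantization parameters of the original per-channel quantized weight
new_tmp = torch._make_per_channel_quantized_tensor(
    tmp_int8,                         # modified int8 values (torch.int8)
    tmp.q_per_channel_scales(),       # per-channel scales of the original weight
    tmp.q_per_channel_zero_points(),  # per-channel zero points of the original weight
    tmp.q_per_channel_axis(),         # channel axis (typically 0 for conv weights)
)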

@jerryzh168 thanks for the reply. Yes, it is per-channel quantization.

So I can now generate a torch.qint8 tensor from the tmp_int8 tensor.
I verified by printing the new_tmp tensor that the changed values are in place.

However, the following line is not updating the model weights:

model_int8.state_dict()['features.0.weight'] = new_tmp

When I print model_int8.state_dict()['features.0.weight'], it still shows the old values. How can I fix this?

Thank you.

I think you probably need to do model_int8.features[0].weight = new_tmp

Hi @jerryzh168, thank you, it worked.

Can you please tell me the difference between model_int8.state_dict()['features.0.weight'] and model_int8.features[0].weight?

After making the change you suggested above, printing model_int8.features[0].weight shows the updated values, but model_int8.state_dict()['features.0.weight'] still shows the old values.

Oh really? model_int8.state_dict() is probably a read-only copy, so assigning into it won't change the original weight. I'm not sure why model_int8.state_dict() is not updated after you modify the weight, though; that is not expected, I think. Are you sure that it did not change?
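
If you do want to go through the state dict, the usual pattern is to modify the copy and then load it back, roughly like this (a sketch):

sd = model_int8.state_dict()         # builds a fresh dict on each call
sd['features.0.weight'] = new_tmp    # this only mutates the copy, not the module
model_int8.load_state_dict(sd)       # writes the modified entries back into the module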

I reran it just now to confirm: model_int8.features[0].weight shows the updated values, but model_int8.state_dict()['features.0.weight'] shows the old values.

Do I need to save and reload the model after a manual weight update? I do not see why that would be necessary.
Also, given the mismatch above, which weight values will be used for inference: model_int8.features[0].weight or model_int8.state_dict()['features.0.weight']?
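
For what it's worth, one way to check which weights inference actually uses is to compare outputs on a fixed quantized input before and after the edit (a sketch; the input shape and quantization parameters here are made-up placeholders):

x = torch.randn(1, 3, 224, 224)  # assumed input shape
xq = torch.quantize_per_tensor(x, scale=1 / 128.0, zero_point=0, dtype=torch.quint8)
out_before = model_int8.features[0](xq)
# ... apply the weight change as above ...
out_after = model_int8.features[0](xq)
# identical outputs would mean the edit never reached the weights used in inference
print(torch.equal(out_before.int_repr(), out_after.int_repr()))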

@jerryzh168 Hey, my problem persists. Please check: Changed quantized weights not reflecting in state_dict()

Thanks for the help.