Register_backward_hook in quantized model

I want to change the gradient of the quantized model (using a straight-through estimator, STE) so that it supports backpropagation, but it fails. Can someone tell me how to use the register_backward_hook() function, or some other method, to achieve this in a quantized model?

import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.f1 = nn.Linear(4, 1, bias=True)
        self.f2 = MyMean()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.weight_init()
        self.quant.register_backward_hook(self.my_hook)

    def forward(self, input):
        self.input = input
        output = self.quant(input)
        output = self.f1(output)  # run operation 1 first, then operation 2
        output = self.quant(self.f2(self.dequant(output)))
        output = self.dequant(output)
        return output

    def weight_init(self):
        self.f1.weight.data.fill_(8.0)  # set the Linear weight to 8
        self.f1.bias.data.fill_(2.0)  # set the Linear bias to 2

    def my_hook(self, module, grad_input, grad_output):
        print('doing my_hook')
        print('original grad:', grad_input)
        print('original outgrad:', grad_output)
        # grad_input = grad_input[0] * self.input
        # grad_input = tuple([grad_input])
        # print('now grad:', grad_input)
        return grad_output

Firstly, unrelated to your question: the model above isn't going to work well for quantization. You have self.quant used at multiple points in the forward, which means that the quantization flow will only be able to assign a single set of quantization parameters that will need to be used in two places, drastically lowering accuracy.

As for your question, the issue is that there are no weight tensors in quantized modules. Those tensors get packed into a special format that the quantized kernel can use more efficiently, so there's nothing to do backprop on.

Generally the way something like this is done is by using fake quants, i.e. modules that simulate quantized numerics with fp32 dtypes (this is what quantization-aware training does). Once training (or whatever you're doing) is complete, the model is then converted to the actual quantized model.
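Roughly the flow looks like this (a minimal sketch with a made-up toy model and training loop, assuming the fbgemm backend and the eager-mode torch.quantization APIs):

import torch
import torch.nn as nn


class QATNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.f1 = nn.Linear(4, 1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.f1(self.quant(x)))


model = QATNet()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_prepared = torch.quantization.prepare_qat(model.train())

# train as usual; the inserted fake-quant modules use a straight-through
# estimator for their gradient, so backprop just works here
opt = torch.optim.SGD(model_prepared.parameters(), lr=0.1)
for _ in range(10):
    x = torch.randn(8, 4)
    loss = model_prepared(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# only once training is done do you convert to the actual int8 model
model_quantized = torch.quantization.convert(model_prepared.eval())

The fake-quant observers that prepare_qat inserts already implement the STE gradient you're describing; convert is only called after training is finished.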

See Quantization — PyTorch 1.13 documentation for more info.

Thank you for your reply! However, my requirement is to make the already-quantized model support backpropagation, so fake quants will not work. How can I get the weight tensors in quantized modules?

Generally module._weight(), but it depends on the module.
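For example, for a quantized Linear produced by the eager-mode flow, something like this works on recent versions (a rough sketch; the exact accessor, weight() / _weight_bias() / _weight(), depends on the module type and PyTorch version):

import torch
import torch.nn as nn


# hypothetical toy model, statically quantized just to get a converted module
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.f1 = nn.Linear(4, 1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.f1(self.quant(x)))


m = Toy().eval()
m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(m, inplace=True)
m(torch.randn(8, 4))                 # calibrate the observers
torch.quantization.convert(m, inplace=True)

qlinear = m.f1                       # now a torch.nn.quantized.Linear
w_q = qlinear.weight()               # unpacks the weight as a qint8 quantized tensor
print(w_q.int_repr())                # raw int8 values
print(w_q.dequantize())              # back to a plain fp32 tensor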

Is there any solution to support backpropagation in a model that has already been quantized?

No, because in the final quantized model the weights are packed so that they are ready to be used in production, so they're not in a tractable format. It's like asking "how can I edit a text file after I've compressed it into a zip?" You can't, unless you want to repeatedly decompress it, make an edit, then recompress it. We don't have any support for something like that, since it'd be faster to just do the edits first and compress afterwards.


What should I do if I want to decompress it? Are there any resources that can help me?

You'd use the method above, which unpacks the quantized tensor.

Realistically you’d be better off writing a custom quantized linear/conv op which stores the quantized tensor normally, then packs it and uses the normal kernel when forward is called. You could then write the autograd function for this op.
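Something along these lines (just a sketch, with made-up scales and zero points and a hand-written straight-through backward; it assumes a CPU quantized backend such as fbgemm is available):

import torch


class QuantLinearSTE(torch.autograd.Function):
    """Straight-through estimator around the quantized linear kernel.

    Forward: quantize the float weight and input, pack the weight, run the
    real int8 kernel, dequantize the result.
    Backward: pretend quantization was the identity and return the usual
    linear-layer gradients computed from the float tensors.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, w_scale, w_zp, x_scale, x_zp, y_scale, y_zp):
        ctx.save_for_backward(x, weight)
        w_q = torch.quantize_per_tensor(weight, w_scale, w_zp, torch.qint8)
        x_q = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8)
        packed = torch.ops.quantized.linear_prepack(w_q, bias)
        y_q = torch.ops.quantized.linear(x_q, packed, y_scale, y_zp)
        return y_q.dequantize()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight            # straight-through wrt input
        grad_w = grad_output.t() @ x             # straight-through wrt weight
        grad_b = grad_output.sum(dim=0)
        return grad_x, grad_w, grad_b, None, None, None, None, None, None


# usage sketch: the float weight stays a normal tensor the optimizer can update
x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(1, 4, requires_grad=True)
b = torch.zeros(1)
y = QuantLinearSTE.apply(x, w, b, 0.05, 0, 0.1, 128, 0.2, 128)
y.sum().backward()
print(w.grad)

Here the forward pass really runs the integer kernel, while the gradients flow to the float master weight that you keep around and re-pack on every call.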
