Custom quantization-aware-training with hooks

Hi, I wish to implement custom quantization-aware training. There are two ways I can think of to do this.

Consider module A followed by module B, and say I wish to quantize only A's output activation, A_act:

[A] --A_act--> [B] --output--> loss

Approach-1 ->

  • quantize A_act in the forward pass (giving A_act_quant)
  • compute the loss and the gradient of the loss w.r.t. the output
  • while backpropagating, compute the weight gradients of module B using A_act_quant
  • set A_act_grad = A_act_quant_grad, i.e. copy the gradient straight through the quantizer

Approach-2 ->

  • quantize A_act in the forward pass (giving A_act_quant)
  • compute the loss and the gradient of the loss w.r.t. the output
  • while backpropagating, compute the weight gradients of module B using the full-precision A_act
  • set A_act_grad = A_act_quant_grad, as in approach-1 (see the sketch after this list)
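
This gradient copy is essentially a straight-through estimator. As a reference point for myself, I believe it can also be written without a custom Function via the detach trick (a minimal sketch under my setup, using binarization as the quantizer):

import torch

def ste_binarize(a_act):
    # Forward value is the binarized activation; backward behaves like the
    # identity, since the detached difference contributes no gradient.
    return a_act + ((a_act >= 0).float() - a_act).detach()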

For approach-2 I am using torch.autograd.Function, quantizing in the forward function and copying the gradient through unchanged in the backward function:

import torch

class Binarizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Saved in case the backward pass ever needs the pre-quantized values
        ctx.save_for_backward(input)
        # Binarize: 1.0 where input >= 0, else 0.0
        return (input >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Straight-through estimator: pass the incoming gradient unchanged
        grad_input = grad_output.clone()
        return grad_input

A_act_quant = Binarizer.apply(A_act)
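
For context, this is roughly how the Function sits in my forward pass (A and B here are hypothetical stand-ins for my real modules):

import torch
import torch.nn as nn

A = nn.Linear(16, 16)   # placeholder for module A
B = nn.Linear(16, 4)    # placeholder for module B

x = torch.randn(8, 16)
A_act = A(x)
A_act_quant = Binarizer.apply(A_act)  # binarized activation feeds module B
loss = B(A_act_quant).sum()
loss.backward()  # gradient reaches A through the straight-through backward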

Can you explain how approach-1 can be implemented?



I think the following should be the way to implement approach-1, but I am not sure (modify only the forward-pass output with a hook). If I modify the forward-pass output of a module, how does backpropagation happen?

import torch
import torch.nn as nn

class QuantizeHookIdentity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

quantmodule = QuantizeHookIdentity()
# A forward hook that returns a value replaces the module's output
quantmodule.register_forward_hook(lambda module, input, output: (output >= 0).float())
A_act_quant = quantmodule(A_act)
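
My concern with this hook version is that (output >= 0).float() has no grad_fn and so detaches the result from the autograd graph, meaning no gradient would reach A at all. A variant I am considering (just a sketch, not tested) is to call the Binarizer from inside the hook, so the forward output is quantized while the backward still copies the gradient through:

quantmodule = QuantizeHookIdentity()
# Routing the replaced output through Binarizer keeps the
# straight-through gradient instead of cutting the graph.
quantmodule.register_forward_hook(lambda module, input, output: Binarizer.apply(output))
A_act_quant = quantmodule(A_act)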