Custom QAT using ao.nn.qat modules: is this a valid approach?

Hello,

The model I’m using has a GRU, which isn’t QAT-ready in torch yet, and I need to do PTQ later in TFLite anyway. So I’ve decided to try doing QAT using ao.nn.qat-style modules, modifying them to add activation FakeQuant, instead of going through the traditional (buggy?) route.

An extract of how I’ve implemented the module:

import torch.nn as nn
from torch import _VF


class GRU_qat(nn.GRU):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(input_size, hidden_size, num_layers, bias, batch_first,
                         dropout, bidirectional, **factory_kwargs)
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # Fake quant for the input activation and for the two weight matrices of
        # layer 0 (weight_ih_l0 and weight_hh_l0); the biases are left in float.
        self.activation_fake_quant = qconfig.activation(factory_kwargs=factory_kwargs)
        self.weight_fake_quant_1 = qconfig.weight(factory_kwargs=factory_kwargs)
        self.weight_fake_quant_2 = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input, hx=None):
...
            # Same call as in nn.GRU.forward, but with fake quant applied to the
            # input and to the first two flat weights before handing them to _VF.gru.
            result = _VF.gru(
                self.activation_fake_quant(input),
                batch_sizes,
                hx,
                [
                    self.weight_fake_quant_1(self._flat_weights[0]),
                    self.weight_fake_quant_2(self._flat_weights[1]),
                    self._flat_weights[2],
                    self._flat_weights[3],
                ],
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
...
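
For reference, this is roughly how I instantiate and exercise it (a minimal sketch: the backend choice, sizes, and input shape are placeholders, and I’m only showing the plain-tensor path even though the extract above shows the packed-sequence branch):

import torch
from torch.ao.quantization import get_default_qat_qconfig

qconfig = get_default_qat_qconfig("qnnpack")   # stock QAT qconfig; backend choice is illustrative

gru = GRU_qat(input_size=64, hidden_size=128, num_layers=1, qconfig=qconfig)

x = torch.randn(10, 4, 64)     # (seq_len, batch, input_size) since batch_first=False
out, h_n = gru(x)              # fake quant runs inside forward() during training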

The process I followed was simple: load the pretrained weights into the original model after swapping in the QAT-equivalent modules, then train normally at a low learning rate for a few epochs on the training dataset. I’ve tested both weight-only fake quant (without activation fake quant) and full QAT. I didn’t see much of an improvement, but my quantized performance was already strong to begin with.
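
In code, that step looks roughly like this (a sketch only: I’m assuming the GRU sits on the model as a `gru` attribute, and `model`, `train_loader`, `criterion`, and the hyperparameters stand in for my actual setup):

import torch

# Swap the pretrained float GRU for its QAT counterpart and copy the weights over.
float_gru = model.gru
qat_gru = GRU_qat(float_gru.input_size, float_gru.hidden_size,
                  num_layers=float_gru.num_layers,
                  batch_first=float_gru.batch_first,
                  qconfig=qconfig)
# strict=False because the fake-quant submodules add buffers the float GRU doesn't have
qat_gru.load_state_dict(float_gru.state_dict(), strict=False)
model.gru = qat_gru

# Short fine-tune at a low learning rate, as described above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train()
for epoch in range(3):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()   # gradients pass through FakeQuantize via the straight-through estimator
        optimizer.step()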

Now my question is: is my method valid? Backpropagation didn’t raise any errors, so I assumed torch handled the fake quantization correctly, but did it really? I’m aware that I ignored layer fusion, batch norm, and math ops. Is their impact significant? I assumed the weight fake quant would matter the most.

Thank you!