Custom GRU module for Quantization Aware Training

I really need to be able to do quantization aware training on GRU layers, and PyTorch doesn’t support it yet. However, it does seem to support static quantization for LSTM layers through custom modules, so I used what was done in torch/ao/nn/quantizable/modules/rnn.py to make a quantizable version of the GRU layers.
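For context, here is roughly what the cell ended up looking like after that port. This is a condensed sketch rather than my exact code, and the class and attribute names (QuantizableGRUCell, igates, hgates, the FloatFunctional attributes) are just the ones I use here; the idea is simply to decompose the GRU math into Linear layers, Sigmoid/Tanh modules and FloatFunctional ops so that observers can be attached to every arithmetic op, like the quantizable LSTMCell does:

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq


class QuantizableGRUCell(nn.Module):
    """One GRU cell decomposed into Linear layers, Sigmoid/Tanh modules and
    FloatFunctional ops, so that eager mode quantization can attach an
    observer to every arithmetic op (same idea as the quantizable LSTMCell)."""

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True) -> None:
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim

        # 3 * hidden_dim covers the reset (r), update (z) and new (n) gates,
        # in the same order as nn.GRUCell's weight_ih / weight_hh.
        self.igates = nn.Linear(input_dim, 3 * hidden_dim, bias=bias)
        self.hgates = nn.Linear(hidden_dim, 3 * hidden_dim, bias=bias)

        self.r_act = nn.Sigmoid()
        self.z_act = nn.Sigmoid()
        self.n_act = nn.Tanh()

        # One FloatFunctional per elementwise op, so each one gets its own observer.
        self.add_r = nnq.FloatFunctional()
        self.add_z = nnq.FloatFunctional()
        self.add_n = nnq.FloatFunctional()
        self.mul_rn = nnq.FloatFunctional()
        self.mul_zn = nnq.FloatFunctional()
        self.mul_zh = nnq.FloatFunctional()
        self.add_out = nnq.FloatFunctional()
        self.one_minus = nnq.FloatFunctional()  # scalar ops for (1 - z)

    def forward(self, x: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        i_r, i_z, i_n = self.igates(x).chunk(3, dim=1)
        h_r, h_z, h_n = self.hgates(hidden).chunk(3, dim=1)

        r = self.r_act(self.add_r.add(i_r, h_r))
        z = self.z_act(self.add_z.add(i_z, h_z))
        n = self.n_act(self.add_n.add(i_n, self.mul_rn.mul(r, h_n)))

        # h' = (1 - z) * n + z * h, written with FloatFunctional ops
        one_minus_z = self.one_minus.add_scalar(self.one_minus.mul_scalar(z, -1.0), 1.0)
        return self.add_out.add(self.mul_zn.mul(one_minus_z, n),
                                self.mul_zh.mul(z, hidden))

    @classmethod
    def from_float(cls, other: nn.GRUCell, qconfig=None) -> "QuantizableGRUCell":
        # Copy the weights of a float nn.GRUCell into the decomposed Linears
        # and propagate the qconfig, like the quantizable LSTMCell does.
        observed = cls(other.input_size, other.hidden_size, bias=other.bias)
        observed.qconfig = getattr(other, "qconfig", qconfig)
        observed.igates.weight = nn.Parameter(other.weight_ih.detach().clone())
        observed.hgates.weight = nn.Parameter(other.weight_hh.detach().clone())
        if other.bias:
            observed.igates.bias = nn.Parameter(other.bias_ih.detach().clone())
            observed.hgates.bias = nn.Parameter(other.bias_hh.detach().clone())
        return observed
```

A full GRU module then stacks these cells over time steps, layers and directions, and its from_float copies the weights from the float nn.GRU, mirroring the quantizable LSTM implementation.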

  1. Am I on the right track by following this? From what I read, the only missing piece at this point is the from_observed() method that can be found in torch/ao/nn/quantized/modules/rnn.py.

  2. What’s the use of calling torch.ao.quantization.prepare() in the from_float() method and torch.ao.quantization.convert() in the from_observed() method? Isn’t that already done when calling prepare_qat() or convert() in the main program, for example?

  3. In the quantization doc, it says “Currently, there is a requirement that ObservedCustomModule will have a single Tensor output, and an observer will be added by the framework (not by the user) on that output. The observer will be stored under the activation_post_process key as an attribute of the custom module instance. Relaxing these restrictions may be done at a future time”. Does it mean that the forward method can’t output both the output features AND the final hidden state?

Thanks a lot and have a nice day!

  1. Yes, you can modify the quantizable LSTM and implement a similar version for GRU to support static quantization of GRU through the custom module API.
  2. We call prepare/convert in from_float and from_observed because we are relying on eager mode quantization to prepare and convert the submodules of the quantizable custom module here (see the sketch after this list).
  3. I think the doc is referring to the support in fx graph mode quantization, and the quantizable LSTM is actually supported there as well now (see pytorch/test/quantization/fx/test_quantize_fx.py in the pytorch/pytorch repo). If you are using the eager mode APIs (e.g. prepare_qat), this shouldn’t be an issue.
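To make point 2 concrete, here is the pattern in miniature. This is a minimal sketch, not the actual LSTM code; FloatBlock, ObservedBlock and QuantizedBlock are made-up names standing in for your float GRU, its quantizable counterpart (the role played by torch/ao/nn/quantizable/modules/rnn.py) and its quantized counterpart (torch/ao/nn/quantized/modules/rnn.py). The prepare/convert calls inside the classmethods are the point:

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq


class FloatBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class ObservedBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    @classmethod
    def from_float(cls, float_mod: FloatBlock) -> "ObservedBlock":
        observed = cls()
        observed.fc = float_mod.fc              # reuse the float weights
        observed.qconfig = float_mod.qconfig
        # The outer prepare()/prepare_qat() treats the custom module as an opaque
        # leaf and only calls from_float() on it, so the custom module itself has
        # to run eager mode prepare()/prepare_qat() on its own submodules to get
        # observers / fake-quant modules attached to them.
        if float_mod.training:
            tq.prepare_qat(observed, inplace=True)
        else:
            tq.prepare(observed, inplace=True)
        return observed


class QuantizedBlock(nn.Module):
    # Instances are only produced by from_observed(), never constructed directly.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    @classmethod
    def from_observed(cls, observed: ObservedBlock) -> "QuantizedBlock":
        # Same idea on the way out: the outer convert() only calls from_observed()
        # on the custom module, so we run eager mode convert() here to swap the
        # observed submodules for their quantized versions.
        converted = tq.convert(observed, inplace=False, remove_qconfig=True)
        converted.__class__ = cls
        return converted
```

These classes are then hooked into the outer eager mode calls through the custom module mappings, i.e. the prepare_custom_config_dict argument with the "float_to_observed_custom_module_class" entry ({FloatBlock: ObservedBlock}) and the convert_custom_config_dict argument with the "observed_to_quantized_custom_module_class" entry ({ObservedBlock: QuantizedBlock}). The outer calls never look inside the custom module, which is exactly why from_float/from_observed have to prepare/convert their own submodules.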